Passed
Pull Request — master (#115)
by Fernando
01:25
created

example_heteromodal   A

Complexity

Total Complexity 0

Size/Duplication

Total Lines 74
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 0
eloc 43
dl 0
loc 74
rs 10
c 0
b 0
f 0
"""
This is an example of a very particular case in which some modalities might be
missing for some of the subjects, as in

    Dorent et al. 2019, Hetero-Modal Variational Encoder-Decoder
    for Joint Modality Completion and Segmentation

"""

import torch.nn as nn
from torch.utils.data import DataLoader

import torchio
from torchio import Image, Subject, ImagesDataset, Queue
from torchio.data import ImageSampler
from torchio.transforms import (
    ZNormalization,
    RandomNoise,
    RandomFlip,
    RandomAffine,
)

# Define training and patches sampling parameters
num_epochs = 4
patch_size = 128
queue_length = 100
samples_per_volume = 1
batch_size = 2

# Populate a list with images
one_subject = Subject(
    T1=Image('../BRATS2018_crop_renamed/LGG75_T1.nii.gz', torchio.INTENSITY),
    T2=Image('../BRATS2018_crop_renamed/LGG75_T2.nii.gz', torchio.INTENSITY),
    label=Image('../BRATS2018_crop_renamed/LGG75_Label.nii.gz', torchio.LABEL),
)

# This subject doesn't have a T2 MRI!
another_subject = Subject(
    T1=Image('../BRATS2018_crop_renamed/LGG74_T1.nii.gz', torchio.INTENSITY),
    label=Image('../BRATS2018_crop_renamed/LGG74_Label.nii.gz', torchio.LABEL),
)

subjects = [
    one_subject,
    another_subject,
]

subjects_dataset = ImagesDataset(subjects)
queue_dataset = Queue(
    subjects_dataset,
    queue_length,
    samples_per_volume,
    patch_size,
    ImageSampler,
)


def collate_as_list(samples):
    """Return the samples of a batch unchanged, as a plain list.

    Needed in the case of missing modalities: subjects do not all share the
    same keys (one has a T2 image, the other does not), so the batch is kept
    as a *list* of samples instead of the dictionary that PyTorch's default
    collate function would build.

    Defined as a named module-level function rather than a lambda so that it
    can be pickled if the DataLoader is ever given worker processes.
    """
    return samples


# In this case, the batch will be composed of a *list* of samples instead of
# the typical Python dictionary that is collated by default in PyTorch
batch_loader = DataLoader(
    queue_dataset,
    batch_size=batch_size,
    collate_fn=collate_as_list,
)

# Mock PyTorch model. nn.Identity is used instead of a bare nn.Module because
# nn.Module defines no forward(), so calling it would raise NotImplementedError.
# nn.Identity simply returns its input, so it accepts the list batch as-is.
model = nn.Identity()

for epoch_index in range(num_epochs):
    for batch in batch_loader:  # batch is a *list* here, not a dictionary
        logits = model(batch)
        # Iterate over the batch itself rather than range(batch_size): the
        # last batch may be shorter, since DataLoader keeps incomplete batches
        # by default (drop_last=False)
        print([sample.keys() for sample in batch])
print()