"""
This is an example of a very particular case in which some modalities might be
missing for some of the subjects, as in

Dorent et al. 2019, Hetero-Modal Variational Encoder-Decoder
for Joint Modality Completion and Segmentation
"""
|
9
|
|
|
|
|
10
|
|
|
import torch.nn as nn |
|
11
|
|
|
from torch.utils.data import DataLoader |
|
12
|
|
|
|
|
13
|
|
|
import torchio |
|
14
|
|
|
from torchio import Image, Subject, ImagesDataset, Queue |
|
15
|
|
|
from torchio.data import ImageSampler |
|
16
|
|
|
|
|
17
|
|
|
def main():
    """Demonstrate patch-based loading when some subjects lack a modality.

    Two subjects are queued for patch sampling; the second has no T2 image.
    Because the resulting samples are heterogeneous, the DataLoader uses an
    identity collate function so each batch stays a *list* of sample dicts.
    """
    # Training and patch-sampling configuration
    num_epochs = 20
    patch_size = 128
    queue_length = 100
    samples_per_volume = 5
    batch_size = 2

    # First subject: has both T1 and T2 plus a label map
    subject_with_both = Subject(
        T1=Image('../BRATS2018_crop_renamed/LGG75_T1.nii.gz', torchio.INTENSITY),
        T2=Image('../BRATS2018_crop_renamed/LGG75_T2.nii.gz', torchio.INTENSITY),
        label=Image('../BRATS2018_crop_renamed/LGG75_Label.nii.gz', torchio.LABEL),
    )

    # Second subject: no T2 MRI — the heterogeneous case this example shows
    subject_missing_t2 = Subject(
        T1=Image('../BRATS2018_crop_renamed/LGG74_T1.nii.gz', torchio.INTENSITY),
        label=Image('../BRATS2018_crop_renamed/LGG74_Label.nii.gz', torchio.LABEL),
    )

    subjects = [
        subject_with_both,
        subject_missing_t2,
    ]

    dataset = ImagesDataset(subjects)
    patches_queue = Queue(
        dataset,
        queue_length,
        samples_per_volume,
        patch_size,
        ImageSampler,
    )

    # The default PyTorch collation would try to stack samples with
    # mismatched keys; an identity collate_fn instead yields each batch
    # as a plain *list* of per-sample dictionaries.
    loader = DataLoader(
        patches_queue,
        batch_size=batch_size,
        collate_fn=lambda samples: samples,
    )

    # Mock PyTorch model standing in for a real network
    model = nn.Identity()

    for _ in range(num_epochs):
        for batch in loader:  # batch is a *list* here, not a dictionary
            logits = model(batch)
            print([batch[i].keys() for i in range(batch_size)])
            print()
|
70
|
|
|
|
|
71
|
|
|
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()