Passed
Push — master ( 3ec417...138f68 )
by Fernando
01:26
created

example_heteromodal   A

Complexity

Total Complexity 4

Size/Duplication

Total Lines 73
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 41
dl 0
loc 73
rs 10
c 0
b 0
f 0
wmc 4

1 Function

Rating   Name   Duplication   Size   Complexity  
A main() 0 52 4
1
"""
2
This is an example of a very particular case in which some modalities might be
3
missing for some of the subjects, as in
4
5
    Dorent et al. 2019, Hetero-Modal Variational Encoder-Decoder
6
    for Joint Modality Completion and Segmentation
7
8
"""
9
10
import torch.nn as nn
11
from torch.utils.data import DataLoader
12
13
import torchio
14
from torchio import Image, Subject, ImagesDataset, Queue
15
from torchio.data import ImageSampler
16
17
def main():
    """Run a mock training loop over subjects with different modalities.

    Two subjects are defined: one with T1, T2 and label images and another
    one with only T1 and label images. The subjects are wrapped in an
    ``ImagesDataset``, fed to a ``Queue`` and batched by a ``DataLoader``
    whose ``collate_fn`` returns the samples untouched, so every batch is a
    *list* of samples. For each batch, the keys of every sample it contains
    are printed.
    """
    # Define training and patches sampling parameters
    num_epochs = 20
    patch_size = 128
    queue_length = 100
    samples_per_volume = 5
    batch_size = 2

    # Populate a list with images
    one_subject = Subject(
        T1=Image('../BRATS2018_crop_renamed/LGG75_T1.nii.gz', torchio.INTENSITY),
        T2=Image('../BRATS2018_crop_renamed/LGG75_T2.nii.gz', torchio.INTENSITY),
        label=Image('../BRATS2018_crop_renamed/LGG75_Label.nii.gz', torchio.LABEL),
    )

    # This subject doesn't have a T2 MRI!
    another_subject = Subject(
        T1=Image('../BRATS2018_crop_renamed/LGG74_T1.nii.gz', torchio.INTENSITY),
        label=Image('../BRATS2018_crop_renamed/LGG74_Label.nii.gz', torchio.LABEL),
    )

    subjects = [
        one_subject,
        another_subject,
    ]

    subjects_dataset = ImagesDataset(subjects)
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        samples_per_volume,
        patch_size,
        ImageSampler,
    )

    # This collate_fn is needed in the case of missing modalities
    # In this case, the batch will be composed by a *list* of samples instead of
    # the typical Python dictionary that is collated by default in Pytorch
    batch_loader = DataLoader(
        queue_dataset,
        batch_size=batch_size,
        collate_fn=lambda x: x,
    )

    # Mock PyTorch model
    model = nn.Identity()

    for epoch_index in range(num_epochs):
        for batch in batch_loader:  # batch is a *list* here, not a dictionary
            # The result is unused; the call marks where a real forward pass
            # would go.
            logits = model(batch)
            # Iterate over the samples actually present in the batch rather
            # than over range(batch_size): the last batch of an epoch may be
            # shorter than batch_size, and indexing it would raise IndexError.
            print([sample.keys() for sample in batch])
    print()
69
70
71
# Run the example only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
73