Passed
Pull Request — master (#287)
by Fernando
01:34
created

TestReproducibility.test_reproducibility_no_seed()   A

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 9
nop 1
dl 0
loc 9
rs 9.95
c 0
b 0
f 0
1
import warnings
2
import torch
3
import torchio
4
import numpy as np
5
from torchio import Subject, ScalarImage, LabelMap
6
from torchio.transforms import RandomNoise, compose_from_history, Compose, RandomSpike, OneOf
7
from ..utils import TorchioTestCase
8
9
10
class TestReproducibility(TorchioTestCase):
    """Check that random transforms can be reproduced from seeds or history.

    Each test transforms one subject, reads back the seed (or the whole
    history) recorded on the output, replays it on an identical subject and
    compares the resulting image data.
    """

    def get_subjects(self):
        """Return two separate but identical subjects.

        Each subject holds a single ``img`` ScalarImage whose tensor is
        ``torch.ones(1, 4, 4, 4)``.
        """
        subject1 = Subject(img=ScalarImage(tensor=torch.ones(1, 4, 4, 4)))
        subject2 = Subject(img=ScalarImage(tensor=torch.ones(1, 4, 4, 4)))
        return subject1, subject2

    def apply_transforms(self, subject, trsfm_list, seeds_list):
        """Apply ``trsfm_list`` to ``subject`` in order, one seed per transform.

        A transform is called as ``trsfm(s, seed=seed)`` when its seed is not
        ``None``, and as ``trsfm(s)`` otherwise.

        Args:
            subject: Subject to transform.
            trsfm_list: Transforms to apply, e.g. the first element returned
                by ``compose_from_history``.
            seeds_list: Seeds aligned with ``trsfm_list``; ``None`` means
                "call without a seed". Pairing follows ``zip`` semantics, so
                extra items in the longer sequence are ignored.

        Returns:
            The subject after all transforms have been applied.
        """
        s = subject
        for trsfm, seed in zip(trsfm_list, seeds_list):
            # Compare against None rather than testing truthiness: 0 is a
            # valid seed and must be forwarded, otherwise the transform would
            # draw a fresh seed and the replay would not be reproducible.
            if seed is not None:
                s = trsfm(s, seed=seed)
            else:
                s = trsfm(s)
        return s

    def random_stuff(self, seed=42):
        """Apply a high-variance ``RandomNoise`` with a fixed seed.

        Not called by the tests visible in this class. ``self.sample_subject``
        is not defined here; presumably it is provided by ``TorchioTestCase``
        (verify).

        Args:
            seed: Seed passed to the transform call.

        Returns:
            Tuple ``(value, recorded_seed)``: the sum of the transformed
            ``img`` data and the seed stored in the first history entry.
        """
        transform = RandomNoise(std=(100, 100))
        transformed = transform(self.sample_subject, seed=seed)
        value = transformed.img.data.sum().item()
        recorded_seed = transformed.history[0][1]['seed']
        return value, recorded_seed

    def test_rng_state(self):
        """Seeded replay reproduces the data without resetting global RNGs.

        Values drawn from torch/NumPy after the unseeded call and after the
        seeded replay must differ, while the transformed data must match.
        """
        trsfm = RandomNoise()
        subject1, subject2 = self.get_subjects()
        transformed1 = trsfm(subject1)
        seed1 = transformed1.history[0][1]['seed']
        value1_torch, value1_np = torch.rand(1).item(), np.random.rand()
        transformed2 = trsfm(subject2, seed=seed1)
        value2_torch, value2_np = torch.rand(1).item(), np.random.rand()
        data1, data2 = transformed1.img.data, transformed2.img.data
        self.assertNotEqual(value1_torch, value2_torch)
        self.assertNotEqual(value1_np, value2_np)
        self.assertTensorEqual(data1, data2)

    def test_reproducibility_seed(self):
        """Re-applying with the recorded seed gives the same data and seed."""
        trsfm = RandomNoise()
        subject1, subject2 = self.get_subjects()
        transformed1 = trsfm(subject1)
        seed1 = transformed1.history[0][1]['seed']
        transformed2 = trsfm(subject2, seed=seed1)
        data1, data2 = transformed1.img.data, transformed2.img.data
        seed2 = transformed2.history[0][1]['seed']
        self.assertTensorEqual(data1, data2)
        self.assertEqual(seed1, seed2)

    def test_reproducibility_no_seed(self):
        """Two unseeded calls record different seeds and produce different data."""
        trsfm = RandomNoise()
        subject1, subject2 = self.get_subjects()
        transformed1 = trsfm(subject1)
        transformed2 = trsfm(subject2)
        data1, data2 = transformed1.img.data, transformed2.img.data
        seed1 = transformed1.history[0][1]['seed']
        seed2 = transformed2.history[0][1]['seed']
        self.assertNotEqual(seed1, seed2)
        self.assertTensorNotEqual(data1, data2)

    def test_reproducibility_from_history(self):
        """A single transform rebuilt from history reproduces the data."""
        trsfm = RandomNoise()
        subject1, subject2 = self.get_subjects()
        transformed1 = trsfm(subject1)
        history1 = transformed1.history
        trsfm_hist, seeds_hist = compose_from_history(history=history1)
        transformed2 = self.apply_transforms(
            subject2, trsfm_list=trsfm_hist, seeds_list=seeds_hist)
        data1, data2 = transformed1.img.data, transformed2.img.data
        self.assertTensorEqual(data1, data2)

    def test_reproducibility_compose(self):
        """A ``Compose`` (one transform with p=0) is reproducible from history."""
        trsfm = Compose([
            RandomNoise(p=0.0),
            RandomSpike(num_spikes=3, p=1.0),
        ])
        subject1, subject2 = self.get_subjects()
        transformed1 = trsfm(subject1)
        history1 = transformed1.history
        trsfm_hist, seeds_hist = compose_from_history(history=history1)
        transformed2 = self.apply_transforms(
            subject2, trsfm_list=trsfm_hist, seeds_list=seeds_hist)
        data1, data2 = transformed1.img.data, transformed2.img.data
        self.assertTensorEqual(data1, data2)

    def test_reproducibility_oneof(self):
        """A ``OneOf`` inside a ``Compose`` is reproducible from history."""
        subject1, subject2 = self.get_subjects()
        trsfm = Compose([
            OneOf([RandomNoise(p=1.0), RandomSpike(num_spikes=3, p=1.0)]),
            RandomNoise(p=.5),
        ])
        transformed1 = trsfm(subject1)
        history1 = transformed1.history
        trsfm_hist, seeds_hist = compose_from_history(history=history1)
        transformed2 = self.apply_transforms(
            subject2, trsfm_list=trsfm_hist, seeds_list=seeds_hist)
        data1, data2 = transformed1.img.data, transformed2.img.data
        self.assertTensorEqual(data1, data2)

    def test_all_random_transforms(self):
        """Every ``Random*`` name exported by ``torchio`` replays from history.

        All such transforms are composed (``RandomDownsample`` last), applied
        once, rebuilt with ``compose_from_history`` and re-applied to the same
        input with the recorded seeds. Both ``t1`` and ``seg`` must match.
        """
        sample = Subject(
            t1=ScalarImage(tensor=torch.rand(1, 20, 20, 20)),
            # torch.rand samples from [0, 1), so this label map is all False.
            # NOTE(review): confirm an empty label map is intended here.
            seg=LabelMap(tensor=torch.rand(1, 20, 20, 20) > 1),
        )

        transforms_names = [
            name
            for name in dir(torchio)
            if name.startswith('Random')
        ]

        # Downsample at the end so that image shape is not modified
        transforms_names.remove('RandomDownsample')
        transforms_names.append('RandomDownsample')

        transforms = []
        for transform_name in transforms_names:
            # Only transform needing an argument for __init__
            if transform_name == 'RandomLabelsToImage':
                transform = getattr(torchio, transform_name)(label_key='seg')
            else:
                transform = getattr(torchio, transform_name)()
            transforms.append(transform)
        composed_transform = torchio.Compose(transforms)
        with warnings.catch_warnings():  # ignore elastic deformation warning
            warnings.simplefilter('ignore', RuntimeWarning)
            transformed = composed_transform(sample)

        # NOTE(review): the replay reuses ``sample``; this assumes the first
        # pass did not modify it in place.
        new_transforms, seeds = compose_from_history(transformed.history)
        # The replay re-runs the same transforms, so silence the same
        # RuntimeWarning as the first pass.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', RuntimeWarning)
            new_transformed = self.apply_transforms(
                subject=sample,
                trsfm_list=new_transforms,
                seeds_list=seeds,
            )

        self.assertTensorEqual(transformed.t1.data, new_transformed.t1.data)
        self.assertTensorEqual(transformed.seg.data, new_transformed.seg.data)
153