Passed
Pull Request — master (#287)
by Fernando
01:18
created

TestReproducibility.random_stuff()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
| Metric | Value |
|--------|-------|
| cc     | 1     |
| eloc   | 6     |
| nop    | 2     |
| dl     | 0     |
| loc    | 6     |
| rs     | 10    |
| c      | 0     |
| b      | 0     |
| f      | 0     |
1
import warnings
2
import torch
3
import torchio
4
import numpy as np
5
from torchio import Subject, ScalarImage, LabelMap
6
from torchio.transforms import RandomNoise, compose_from_history, Compose, RandomSpike
7
from ..utils import TorchioTestCase
8
9
10
class TestReproducibility(TorchioTestCase):
    """Check that random torchio transforms can be replayed from their seeds.

    Each test applies a random transform, reads back the seed recorded in the
    transformed subject's history, and verifies that re-applying a transform
    with that seed yields identical image data.
    """

    @staticmethod
    def _get_seed(transformed, index=0):
        """Return the seed stored in the ``index``-th history entry.

        History entries are ``(transform_name, params_dict)`` pairs; the seed
        used for that transform is stored under ``params_dict['seed']``.
        """
        return transformed.history[index][1]['seed']

    @staticmethod
    def _instantiate_transform(transform_name):
        """Build the torchio transform named ``transform_name``.

        ``RandomLabelsToImage`` is the only transform here that needs an
        argument for ``__init__``; it is pointed at the ``seg`` label map.
        Every other transform is created with its default arguments.
        """
        transform_class = getattr(torchio, transform_name)
        if transform_name == 'RandomLabelsToImage':
            return transform_class(label_key='seg')
        return transform_class()

    def get_subjects(self):
        """Return two independent subjects holding identical all-ones images."""
        subject1 = Subject(img=ScalarImage(tensor=torch.ones(1, 4, 4, 4)))
        subject2 = Subject(img=ScalarImage(tensor=torch.ones(1, 4, 4, 4)))
        return subject1, subject2

    def random_stuff(self, seed=42):
        """Apply a strong ``RandomNoise`` to ``self.sample_subject``.

        Args:
            seed: Seed passed to the transform call.

        Returns:
            A ``(value, seed)`` tuple: the sum of the transformed ``img`` data
            and the seed recorded in the transform history.
        """
        transform = RandomNoise(std=(100, 100))
        transformed = transform(self.sample_subject, seed=seed)
        value = transformed.img.data.sum().item()
        # Return the seed the transform recorded, kept in a separate name
        # instead of overwriting the ``seed`` argument.
        recorded_seed = self._get_seed(transformed)
        return value, recorded_seed

    def test_rng_state(self):
        """A seeded call must reproduce data without freezing the global RNGs."""
        trsfm = RandomNoise()
        subject1, subject2 = self.get_subjects()
        transformed1 = trsfm(subject1)
        seed1 = self._get_seed(transformed1)
        value1_torch, value1_np = torch.rand(1).item(), np.random.rand()
        transformed2 = trsfm(subject2, seed=seed1)
        value2_torch, value2_np = torch.rand(1).item(), np.random.rand()
        data1, data2 = transformed1.img.data, transformed2.img.data
        # Draws from the global torch/NumPy generators taken after each call
        # must differ, i.e. reusing seed1 did not replay the global state.
        self.assertNotEqual(value1_torch, value2_torch)
        self.assertNotEqual(value1_np, value2_np)
        self.assertTensorEqual(data1, data2)

    def test_reproducibility_seed(self):
        """Passing a recorded seed back reproduces data and records that seed."""
        trsfm = RandomNoise()
        subject1, subject2 = self.get_subjects()
        transformed1 = trsfm(subject1)
        seed1 = self._get_seed(transformed1)
        transformed2 = trsfm(subject2, seed=seed1)
        data1, data2 = transformed1.img.data, transformed2.img.data
        seed2 = self._get_seed(transformed2)
        self.assertTensorEqual(data1, data2)
        self.assertEqual(seed1, seed2)

    def test_reproducibility_no_seed(self):
        """Two unseeded calls draw different seeds and produce different data."""
        trsfm = RandomNoise()
        subject1, subject2 = self.get_subjects()
        transformed1 = trsfm(subject1)
        transformed2 = trsfm(subject2)
        data1, data2 = transformed1.img.data, transformed2.img.data
        seed1, seed2 = self._get_seed(transformed1), self._get_seed(transformed2)
        self.assertNotEqual(seed1, seed2)
        self.assertTensorNotEqual(data1, data2)

    def test_reproducibility_from_history(self):
        """A transform rebuilt with compose_from_history reproduces the data."""
        trsfm = RandomNoise()
        subject1, subject2 = self.get_subjects()
        transformed1 = trsfm(subject1)
        history1 = transformed1.history
        compose_hist, seeds_hist = compose_from_history(history=history1)
        transformed2 = compose_hist(subject2, seeds=seeds_hist)
        data1, data2 = transformed1.img.data, transformed2.img.data
        self.assertTensorEqual(data1, data2)

    def test_reproducibility_compose(self):
        """compose_from_history also replays a Compose, including a p=0 step."""
        trsfm = Compose([RandomNoise(p=0.0), RandomSpike(num_spikes=3, p=1.0)])
        subject1, subject2 = self.get_subjects()
        transformed1 = trsfm(subject1)
        history1 = transformed1.history
        compose_hist, seeds_hist = compose_from_history(history=history1)
        transformed2 = compose_hist(subject2, seeds=seeds_hist)
        data1, data2 = transformed1.img.data, transformed2.img.data
        self.assertTensorEqual(data1, data2)

    def test_all_random_transforms(self):
        """Every ``Random*`` transform in torchio can be replayed from its seed."""
        sample = Subject(
            t1=ScalarImage(tensor=torch.rand(1, 20, 20, 20)),
            seg=LabelMap(tensor=torch.rand(1, 20, 20, 20) > 1)
        )

        transforms_names = [
            name
            for name in dir(torchio)
            if name.startswith('Random')
        ]

        # Downsample at the end so that image shape is not modified
        transforms_names.remove('RandomDownsample')
        transforms_names.append('RandomDownsample')

        transforms = [
            self._instantiate_transform(name) for name in transforms_names
        ]
        composed_transform = torchio.Compose(transforms)
        with warnings.catch_warnings():  # ignore elastic deformation warning
            warnings.simplefilter('ignore', RuntimeWarning)
            transformed = composed_transform(sample)

        new_transforms = []
        seeds = []
        for transform_name, params_dict in transformed.history:
            # The Resample entry in the history comes from RandomDownsample;
            # it and the Compose entry are skipped and not rebuilt here.
            if transform_name in ('Resample', 'Compose'):
                continue
            new_transforms.append(self._instantiate_transform(transform_name))
            seeds.append(params_dict['seed'])

        composed_transform = torchio.Compose(new_transforms)
        with warnings.catch_warnings():  # ignore elastic deformation warning
            warnings.simplefilter('ignore', RuntimeWarning)
            new_transformed = composed_transform(sample, seeds=seeds)
        self.assertTensorEqual(transformed.t1.data, new_transformed.t1.data)
        self.assertTensorEqual(transformed.seg.data, new_transformed.seg.data)
129