Passed
Pull Request — master (#287)
by Fernando
01:37
created

TestReproducibility.test_rng_state()   A

Complexity

Conditions 1

Size

Total Lines 12
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 12
nop 1
dl 0
loc 12
rs 9.8
c 0
b 0
f 0
1
import warnings
2
import torch
3
import torchio
4
import numpy as np
5
from torchio import Subject, Image, INTENSITY, DATA
6
from torchio.transforms import RandomNoise, compose_from_history, Compose, RandomSpike
7
from ..utils import TorchioTestCase
8
9
10
class TestReproducibility(TorchioTestCase):
    """Check that random transforms can be replayed exactly from their seeds.

    The tests apply a random transform, then re-apply it either with an
    explicit seed or rebuilt from the subject's ``history``, and compare the
    resulting image data. ``test_rng_state`` additionally checks that seeding
    a transform does not make the global torch / NumPy generators repeat.
    """

    # Constructor keyword arguments for transforms that cannot be built with
    # no arguments. ``RandomLabelsToImage`` must be told which image holds the
    # labels. Looked up by key (i.e. string equality), never by identity.
    _TRANSFORM_KWARGS = {'RandomLabelsToImage': {'label_key': 'seg'}}

    @staticmethod
    def _make_subject():
        """Return a new subject holding one ``img`` tensor of ones, shape (1, 4, 4, 4)."""
        return Subject(img=Image(tensor=torch.ones(1, 4, 4, 4)))

    @staticmethod
    def _first_seed(subject):
        """Return the ``'seed'`` recorded by the first entry of ``subject.history``.

        Each history entry is unpacked elsewhere in this class as a
        ``(transform_name, params_dict)`` pair.
        """
        return subject.history[0][1]['seed']

    @classmethod
    def _instantiate(cls, transform_name):
        """Build the ``torchio`` transform named ``transform_name``.

        Extra constructor arguments come from ``_TRANSFORM_KWARGS``. All other
        transforms are built with no arguments.
        """
        kwargs = cls._TRANSFORM_KWARGS.get(transform_name, {})
        return getattr(torchio, transform_name)(**kwargs)

    @staticmethod
    def _apply_quietly(transform, subject, **kwargs):
        """Call ``transform(subject, **kwargs)`` with ``UserWarning`` silenced.

        This ignores e.g. the elastic deformation warning.
        """
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', UserWarning)
            return transform(subject, **kwargs)

    def random_stuff(self, seed=42):
        """Apply strong ``RandomNoise`` to ``self.subject`` using ``seed``.

        ``self.subject`` is presumably created by ``TorchioTestCase.setUp``
        (TODO confirm).

        Returns:
            A ``(value, seed)`` tuple. ``value`` is the sum of the transformed
            ``img`` data. ``seed`` is the seed recorded in the history.
        """
        transform = RandomNoise(std=(100, 100))
        transformed = transform(self.subject, seed=seed)
        value = transformed.img.data.sum().item()
        return value, self._first_seed(transformed)

    def test_rng_state(self):
        """Seeding a transform must not reset the global torch / NumPy RNGs."""
        trsfm = RandomNoise()
        subject1, subject2 = self._make_subject(), self._make_subject()
        transformed1 = trsfm(subject1)
        seed1 = self._first_seed(transformed1)
        value1_torch, value1_np = torch.rand(1).item(), np.random.rand()
        transformed2 = trsfm(subject2, seed=seed1)
        # Global draws after the seeded call must differ from the earlier ones.
        # Equal values would mean the seed leaked into the global generators.
        value2_torch, value2_np = torch.rand(1).item(), np.random.rand()
        data1, data2 = transformed1['img'][DATA], transformed2['img'][DATA]
        self.assertNotEqual(value1_torch, value2_torch)
        self.assertNotEqual(value1_np, value2_np)
        self.assertTensorEqual(data1, data2)

    def test_reproducibility_seed(self):
        """Re-applying a transform with a recorded seed reproduces the output."""
        trsfm = RandomNoise()
        subject1, subject2 = self._make_subject(), self._make_subject()
        transformed1 = trsfm(subject1)
        seed1 = self._first_seed(transformed1)
        transformed2 = trsfm(subject2, seed=seed1)
        data1, data2 = transformed1['img'][DATA], transformed2['img'][DATA]
        seed2 = self._first_seed(transformed2)
        self.assertTensorEqual(data1, data2)
        self.assertEqual(seed1, seed2)

    def test_reproducibility_no_seed(self):
        """Two unseeded applications record different seeds and outputs."""
        trsfm = RandomNoise()
        subject1, subject2 = self._make_subject(), self._make_subject()
        transformed1 = trsfm(subject1)
        transformed2 = trsfm(subject2)
        data1, data2 = transformed1['img'][DATA], transformed2['img'][DATA]
        seed1 = self._first_seed(transformed1)
        seed2 = self._first_seed(transformed2)
        self.assertNotEqual(seed1, seed2)
        self.assertTensorNotEqual(data1, data2)

    def test_reproducibility_from_history(self):
        """A transform rebuilt with ``compose_from_history`` reproduces the output."""
        trsfm = RandomNoise()
        subject1, subject2 = self._make_subject(), self._make_subject()
        transformed1 = trsfm(subject1)
        compose_hist, seeds_hist = compose_from_history(
            history=transformed1.history,
        )
        transformed2 = compose_hist(subject2, seeds=seeds_hist)
        data1, data2 = transformed1['img'][DATA], transformed2['img'][DATA]
        self.assertTensorEqual(data1, data2)

    def test_reproducibility_compose(self):
        """History replay also works for a ``Compose`` with a skipped (p=0) step."""
        trsfm = Compose([RandomNoise(p=0.0), RandomSpike(num_spikes=3, p=1.0)])
        subject1, subject2 = self._make_subject(), self._make_subject()
        transformed1 = trsfm(subject1)
        compose_hist, seeds_hist = compose_from_history(
            history=transformed1.history,
        )
        transformed2 = compose_hist(subject2, seeds=seeds_hist)
        data1, data2 = transformed1['img'][DATA], transformed2['img'][DATA]
        self.assertTensorEqual(data1, data2)

    def test_all_random_transforms(self):
        """Replaying every ``Random*`` transform from the history gives the same output."""
        sample = Subject(
            t1=Image(tensor=torch.rand(1, 20, 20, 20)),
            # NOTE(review): torch.rand samples from [0, 1), so ``> 1`` yields an
            # all-False mask. Confirm an empty "segmentation" is intended.
            seg=Image(tensor=torch.rand(1, 20, 20, 20) > 1, type=INTENSITY),
        )

        transforms_names = [
            name for name in dir(torchio) if name.startswith('Random')
        ]
        # Downsample at the end so that the image shape is not modified.
        transforms_names.remove('RandomDownsample')
        transforms_names.append('RandomDownsample')

        composed_transform = torchio.Compose(
            [self._instantiate(name) for name in transforms_names]
        )
        transformed = self._apply_quietly(composed_transform, sample)

        new_transforms = []
        seeds = []
        for transform_name, params_dict in transformed.history:
            # The ``Resample`` entry in the history comes from the
            # downsampling, so it is not re-created. ``Compose`` entries are
            # skipped as well.
            if transform_name in ('Resample', 'Compose'):
                continue
            new_transforms.append(self._instantiate(transform_name))
            seeds.append(params_dict['seed'])

        new_transformed = self._apply_quietly(
            torchio.Compose(new_transforms), sample, seeds=seeds,
        )
        self.assertTensorEqual(transformed.t1.data, new_transformed.t1.data)
        self.assertTensorEqual(transformed.seg.data, new_transformed.seg.data)
127