Passed
Pull Request — master (#226)
by Fernando
01:12
created

TestReproducibility.setUp()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import warnings
2
import torch
3
import torchio
4
from torchio import Subject, Image, INTENSITY
5
from torchio.transforms import RandomNoise
6
from ..utils import TorchioTestCase
7
8
9
class TestReproducibility(TorchioTestCase):
    """Check that random transforms can be replayed from their recorded seed."""

    def setUp(self):
        super().setUp()
        # Constant input (a 4x4x4 tensor of ones), so any difference between
        # outputs comes from the transform's randomness, not from the data.
        self.subject = Subject(img=Image(tensor=torch.ones(4, 4, 4)))

    def random_stuff(self, seed=None):
        """Apply ``RandomNoise`` to the test subject and summarize the result.

        Args:
            seed: Seed forwarded to the transform call. With ``None`` no seed
                is forced; the seed actually used is read back from the
                subject's applied-transforms record.

        Returns:
            Tuple ``(value, seed)``: ``value`` is the sum of the transformed
            image data, ``seed`` is the seed recorded for the first applied
            transform.
        """
        # The std range (100, 100) is degenerate, so presumably only the
        # random draw varies between calls, not the noise magnitude.
        # The seed is passed at call time, not to the constructor.
        transform = RandomNoise(std=(100, 100))
        transformed = transform(self.subject, seed=seed)
        value = transformed.img.data.sum().item()
        _, seed = transformed.get_applied_transforms()[0]
        return value, seed

    def test_reproducibility_no_seed(self):
        """Replaying an unseeded run with its recorded seed reproduces it."""
        a, seed_a = self.random_stuff()
        b, _ = self.random_stuff()
        # Two unseeded runs should differ.
        self.assertNotEqual(a, b)
        # Re-running with the seed recorded by the first run must give back
        # exactly the same output and the same recorded seed.
        c, seed_c = self.random_stuff(seed_a)
        self.assertEqual(c, a)
        self.assertEqual(seed_c, seed_a)

    def test_reproducibility_seed(self):
        """Seeding torch's global RNG makes whole sequences of runs repeatable."""
        torch.manual_seed(42)
        a, seed_a = self.random_stuff()
        b, seed_b = self.random_stuff()
        self.assertNotEqual(a, b)
        c, seed_c = self.random_stuff(seed_a)
        self.assertEqual(c, a)
        self.assertEqual(seed_c, seed_a)

        # Resetting the global seed must replay the same sequence of
        # outputs and recorded seeds as the block above.
        torch.manual_seed(42)
        a2, seed_a2 = self.random_stuff()
        self.assertEqual(a2, a)
        self.assertEqual(seed_a2, seed_a)
        b2, seed_b2 = self.random_stuff()
        self.assertNotEqual(a2, b2)
        self.assertEqual(b2, b)
        self.assertEqual(seed_b2, seed_b)
        c2, seed_c2 = self.random_stuff(seed_a2)
        self.assertEqual(c2, a2)
        self.assertEqual(seed_c2, seed_a2)
        self.assertEqual(c2, c)
        self.assertEqual(seed_c2, seed_c)

    # NOTE(review): disabled test. It builds transforms with a constructor
    # ``seed=`` argument and reads ``transformed.history``. The tests above
    # pass the seed at call time instead, so this may no longer match the
    # current API. TODO: confirm and port before re-enabling.
    # The replay step now composes ``new_transforms`` (the seeded copies).
    # Previously it re-used the original ``transforms``, which made the
    # comparison meaningless.
    # def test_all_random_transforms(self):
    #     sample = Subject(
    #         t1=Image(tensor=torch.rand(20, 20, 20)),
    #         seg=Image(tensor=torch.rand(20, 20, 20) > 1, type=INTENSITY)
    #     )
    #
    #     transforms_names = [
    #         name
    #         for name in dir(torchio)
    #         if name.startswith('Random')
    #     ]
    #
    #     # Downsample at the end so that the image shape is not modified
    #     transforms_names.remove('RandomDownsample')
    #     transforms_names.append('RandomDownsample')
    #
    #     transforms = []
    #     for transform_name in transforms_names:
    #         transform = getattr(torchio, transform_name)()
    #         transforms.append(transform)
    #     composed_transform = torchio.Compose(transforms)
    #     with warnings.catch_warnings():  # ignore elastic deformation warning
    #         warnings.simplefilter('ignore', UserWarning)
    #         transformed = composed_transform(sample)
    #
    #     new_transforms = []
    #     for transform_name, params_dict in transformed.history:
    #         transform_class = getattr(torchio, transform_name)
    #         transform = transform_class(seed=params_dict['seed'])
    #         new_transforms.append(transform)
    #     composed_transform = torchio.Compose(new_transforms)
    #     with warnings.catch_warnings():  # ignore elastic deformation warning
    #         warnings.simplefilter('ignore', UserWarning)
    #         new_transformed = composed_transform(sample)
    #
    #     self.assertTensorEqual(transformed.t1.data, new_transformed.t1.data)
    #     self.assertTensorEqual(transformed.seg.data, new_transformed.seg.data)
90