Passed
Pull Request — master (#226)
by Fernando
01:17
created

tests.transforms.test_reproducibility   A

Complexity

Total Complexity 4

Size/Duplication

Total Lines 50
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 44
dl 0
loc 50
rs 10
c 0
b 0
f 0
wmc 4

4 Methods

Rating   Name   Duplication   Size   Complexity  
A TestReproducibility.test_reproducibility_no_seed() 0 7 1
A TestReproducibility.test_reproducibility_seed() 0 22 1
A TestReproducibility.random_stuff() 0 6 1
A TestReproducibility.setUp() 0 3 1
1
import torch
2
from torchio import Subject, Image
3
from torchio.transforms import RandomNoise
4
from ..utils import TorchioTestCase
5
6
7
class TestReproducibility(TorchioTestCase):
    """Check that ``RandomNoise`` outputs can be reproduced from their seed."""

    def setUp(self):
        super().setUp()
        # Constant input image, so any difference between outputs comes from
        # the noise sampled by the transform.
        self.subject = Subject(img=Image(tensor=torch.ones(4, 4, 4)))

    def random_stuff(self, seed=None):
        """Apply ``RandomNoise`` to the test subject once.

        Args:
            seed: Seed forwarded to ``RandomNoise``. With ``None`` the
                transform is left to choose its own seed.

        Returns:
            Tuple ``(value, seed)``: ``value`` is the sum of the transformed
            image data (a Python number), and ``seed`` is the ``'seed'`` entry
            of the parameters stored in the first history item.
        """
        # NOTE(review): std=(100, 100) presumably is a (min, max) range with
        # equal bounds, i.e. a fixed, large std so that different seeds
        # produce clearly different sums — confirm against RandomNoise.
        transform = RandomNoise(std=(100, 100), seed=seed)
        transformed = transform(self.subject)
        value = transformed.img.data.sum().item()
        # NOTE(review): assumes each history entry is a (name, params) pair
        # whose params dict records the seed actually used — confirm.
        random_params_dict = transformed.history[0][1]
        return value, random_params_dict['seed']

    def test_reproducibility_no_seed(self):
        """Without a global seed, runs differ but a recorded seed replays."""
        a, seed_a = self.random_stuff()
        b, _ = self.random_stuff()
        self.assertNotEqual(a, b)
        # Re-running with the seed recorded for ``a`` must reproduce ``a``.
        c, seed_c = self.random_stuff(seed_a)
        self.assertEqual(c, a)
        self.assertEqual(seed_c, seed_a)

    def test_reproducibility_seed(self):
        """The same global torch seed yields the same sequence of runs."""
        # fork_rng restores the global RNG state when the block exits, so
        # the manual_seed calls below do not leak into tests run afterwards.
        with torch.random.fork_rng():
            torch.manual_seed(42)
            a, seed_a = self.random_stuff()
            b, seed_b = self.random_stuff()
            self.assertNotEqual(a, b)
            c, seed_c = self.random_stuff(seed_a)
            self.assertEqual(c, a)
            self.assertEqual(seed_c, seed_a)

            # Re-seeding with the same value must replay the whole sequence.
            torch.manual_seed(42)
            a2, seed_a2 = self.random_stuff()
            self.assertEqual(a2, a)
            self.assertEqual(seed_a2, seed_a)
            b2, seed_b2 = self.random_stuff()
            self.assertNotEqual(a2, b2)
            self.assertEqual(b2, b)
            self.assertEqual(seed_b2, seed_b)
            c2, seed_c2 = self.random_stuff(seed_a2)
            self.assertEqual(c2, a2)
            self.assertEqual(seed_c2, seed_a2)
            self.assertEqual(c2, c)
            self.assertEqual(seed_c2, seed_c)
50