|
1
|
|
|
import torch |
|
2
|
|
|
from torchio import Subject, Image |
|
3
|
|
|
from torchio.transforms import RandomNoise |
|
4
|
|
|
from ..utils import TorchioTestCase |
|
5
|
|
|
|
|
6
|
|
|
|
|
7
|
|
|
class TestReproducibility(TorchioTestCase):
    """Check that ``RandomNoise`` results can be replayed from a seed."""

    def setUp(self):
        """Run the parent fixture, then build a one-image subject of ones."""
        super().setUp()
        self.subject = Subject(img=Image(tensor=torch.ones(4, 4, 4)))

    def random_stuff(self, seed=None):
        """Apply ``RandomNoise`` to the fixture subject and summarize it.

        Args:
            seed: Value forwarded to ``RandomNoise(seed=...)``. ``None``
                (the default) is passed through unchanged.

        Returns:
            A ``(value, seed)`` tuple, where ``value`` is the sum of the
            transformed image data as a Python scalar and ``seed`` is the
            ``'seed'`` entry read from the first history record of the
            transformed subject.
        """
        transform = RandomNoise(std=(100, 100), seed=seed)
        transformed = transform(self.subject)
        value = transformed.img.data.sum().item()
        # NOTE(review): element [1] of the first history entry is presumably
        # the dict of parameters recorded by the transform -- confirm against
        # the torchio history format.
        random_params_dict = transformed.history[0][1]
        return value, random_params_dict['seed']

    def test_reproducibility_no_seed(self):
        """Two unseeded runs differ; replaying the first seed matches it."""
        a, seed_a = self.random_stuff()
        # Only the value of the second run is compared; its seed is not used.
        b, _ = self.random_stuff()
        self.assertNotEqual(a, b)
        c, seed_c = self.random_stuff(seed_a)
        self.assertEqual(c, a)
        self.assertEqual(seed_c, seed_a)

    def test_reproducibility_seed(self):
        """Re-seeding torch replays the same sequence of values and seeds.

        The order of the ``random_stuff`` calls matters: the second half of
        the test repeats the first half call for call after resetting the
        global torch seed, and compares each result with its counterpart.
        """
        torch.manual_seed(42)
        a, seed_a = self.random_stuff()
        b, seed_b = self.random_stuff()
        self.assertNotEqual(a, b)
        c, seed_c = self.random_stuff(seed_a)
        self.assertEqual(c, a)
        self.assertEqual(seed_c, seed_a)

        # Same global seed again: every run below must match the run at the
        # same position above.
        torch.manual_seed(42)
        a2, seed_a2 = self.random_stuff()
        self.assertEqual(a2, a)
        self.assertEqual(seed_a2, seed_a)
        b2, seed_b2 = self.random_stuff()
        self.assertNotEqual(a2, b2)
        self.assertEqual(b2, b)
        self.assertEqual(seed_b2, seed_b)
        c2, seed_c2 = self.random_stuff(seed_a2)
        self.assertEqual(c2, a2)
        self.assertEqual(seed_c2, seed_a2)
        self.assertEqual(c2, c)
        self.assertEqual(seed_c2, seed_c)
|
50
|
|
|
|