|
1
|
|
|
import warnings |
|
2
|
|
|
import torch |
|
3
|
|
|
import torchio |
|
4
|
|
|
from torchio import Subject, Image, INTENSITY |
|
5
|
|
|
from torchio.transforms import RandomNoise |
|
6
|
|
|
from ..utils import TorchioTestCase |
|
7
|
|
|
|
|
8
|
|
|
|
|
9
|
|
|
class TestReproducibility(TorchioTestCase):
    """Check that ``RandomNoise`` output can be reproduced from its seed."""

    def setUp(self):
        super().setUp()
        # Snapshot the global CPU RNG state and restore it after each test,
        # so the torch.manual_seed() calls below do not leak into other tests.
        self.addCleanup(torch.set_rng_state, torch.get_rng_state())
        self.subject = Subject(img=Image(tensor=torch.ones(4, 4, 4)))

    def random_stuff(self, seed=None):
        """Apply ``RandomNoise`` to ``self.subject`` once.

        Args:
            seed: Seed forwarded to the transform call; ``None`` means no
                explicit seed is given.

        Returns:
            A ``(value, seed)`` tuple: the sum of the transformed image data
            and the seed read back from the first applied-transform entry.
        """
        # Presumably a fixed noise std of 100, so that different draws give
        # clearly different sums.
        transform = RandomNoise(std=(100, 100))
        transformed = transform(self.subject, seed=seed)
        value = transformed.img.data.sum().item()
        # NOTE(review): assumes each entry is a (transform, seed) pair; confirm.
        _, applied_seed = transformed.get_applied_transforms()[0]
        return value, applied_seed

    def test_reproducibility_no_seed(self):
        """Unseeded calls differ; replaying a returned seed reproduces them."""
        a, seed_a = self.random_stuff()
        b, seed_b = self.random_stuff()
        self.assertNotEqual(a, b)
        c, seed_c = self.random_stuff(seed_a)
        self.assertEqual(c, a)
        self.assertEqual(seed_c, seed_a)

    def test_reproducibility_seed(self):
        """Re-seeding the global RNG replays the same sequence of results."""
        torch.manual_seed(42)
        a, seed_a = self.random_stuff()
        b, seed_b = self.random_stuff()
        self.assertNotEqual(a, b)
        c, seed_c = self.random_stuff(seed_a)
        self.assertEqual(c, a)
        self.assertEqual(seed_c, seed_a)

        # Same global seed again: every value and seed must match the first run.
        torch.manual_seed(42)
        a2, seed_a2 = self.random_stuff()
        self.assertEqual(a2, a)
        self.assertEqual(seed_a2, seed_a)
        b2, seed_b2 = self.random_stuff()
        self.assertNotEqual(a2, b2)
        self.assertEqual(b2, b)
        self.assertEqual(seed_b2, seed_b)
        c2, seed_c2 = self.random_stuff(seed_a2)
        self.assertEqual(c2, a2)
        self.assertEqual(seed_c2, seed_a2)
        self.assertEqual(c2, c)
        self.assertEqual(seed_c2, seed_c)

    # TODO: disabled draft. It relies on ``transformed.history`` and on a
    # ``seed=`` constructor argument, whereas the active tests above pass the
    # seed to the transform call instead; update before re-enabling.
    # def test_all_random_transforms(self):
    #     sample = Subject(
    #         t1=Image(tensor=torch.rand(20, 20, 20)),
    #         # NOTE(review): torch.rand is in [0, 1), so ``> 1`` is always
    #         # False; a segmentation presumably also wants a label type
    #         # rather than INTENSITY.
    #         seg=Image(tensor=torch.rand(20, 20, 20) > 1, type=INTENSITY)
    #     )
    #
    #     transforms_names = [
    #         name
    #         for name in dir(torchio)
    #         if name.startswith('Random')
    #     ]
    #
    #     # Downsample at the end so that the image shape is not modified
    #     transforms_names.remove('RandomDownsample')
    #     transforms_names.append('RandomDownsample')
    #
    #     transforms = []
    #     for transform_name in transforms_names:
    #         transform = getattr(torchio, transform_name)()
    #         transforms.append(transform)
    #     composed_transform = torchio.Compose(transforms)
    #     with warnings.catch_warnings():  # ignore elastic deformation warning
    #         warnings.simplefilter('ignore', UserWarning)
    #         transformed = composed_transform(sample)
    #
    #     new_transforms = []
    #     for transform_name, params_dict in transformed.history:
    #         transform_class = getattr(torchio, transform_name)
    #         transform = transform_class(seed=params_dict['seed'])
    #         new_transforms.append(transform)
    #     # Compose the re-seeded transforms (the draft used ``transforms``).
    #     composed_transform = torchio.Compose(new_transforms)
    #     with warnings.catch_warnings():  # ignore elastic deformation warning
    #         warnings.simplefilter('ignore', UserWarning)
    #         new_transformed = composed_transform(sample)
    #
    #     self.assertTensorEqual(transformed.t1.data, new_transformed.t1.data)
    #     self.assertTensorEqual(transformed.seg.data, new_transformed.seg.data)
|
90
|
|
|
|