|
1
|
|
|
"""Tests for custom Image subclasses with transforms.""" |
|
2
|
|
|
|
|
3
|
|
|
import torch |
|
4
|
|
|
import pytest |
|
5
|
|
|
import torchio as tio |
|
6
|
|
|
|
|
7
|
|
|
|
|
8
|
|
View Code Duplication |
class HistoryImage(tio.ScalarImage):
    """Scalar image used in tests whose constructor requires ``history``.

    ``history`` has no default, so an instance cannot be rebuilt from a
    tensor and an affine alone; ``new_like`` supplies it from the
    current instance.
    """

    def __init__(self, tensor, affine, history, **kwargs):
        """Initialise the base image, then store ``history`` on the instance.

        Args:
            tensor: Forwarded to the parent constructor as ``tensor``.
            affine: Forwarded to the parent constructor as ``affine``.
            history: Arbitrary value kept as ``self.history``.
            **kwargs: Extra keyword arguments forwarded to the parent.
        """
        super().__init__(tensor=tensor, affine=affine, **kwargs)
        self.history = history

    def new_like(self, tensor, affine=None):
        """Create another instance of this class around ``tensor``.

        Args:
            tensor: Data for the new image.
            affine: Affine for the new image; when ``None`` the affine
                of the current image is reused.

        Returns:
            A new ``type(self)`` instance that shares this image's
            ``history``, ``check_nans`` and ``reader`` values.
        """
        image_class = type(self)
        if affine is None:
            affine = self.affine
        init_kwargs = {
            'tensor': tensor,
            'affine': affine,
            'history': self.history,
            'check_nans': self.check_nans,
            'reader': self.reader,
        }
        return image_class(**init_kwargs)
|
23
|
|
|
|
|
24
|
|
|
|
|
25
|
|
|
class MetadataImage(tio.ScalarImage):
    """Scalar image used in tests that accepts an optional ``metadata``."""

    def __init__(self, tensor, affine, metadata=None, **kwargs):
        """Initialise the base image, then store ``metadata`` on the instance.

        Args:
            tensor: Forwarded to the parent constructor as ``tensor``.
            affine: Forwarded to the parent constructor as ``affine``.
            metadata: Value kept as ``self.metadata``. Any falsy value
                (including ``None`` and an empty dict) is replaced by a
                fresh empty dict.
            **kwargs: Extra keyword arguments forwarded to the parent.
        """
        super().__init__(tensor=tensor, affine=affine, **kwargs)
        if not metadata:
            metadata = {}
        self.metadata = metadata

    def new_like(self, tensor, affine=None):
        """Create another instance of this class around ``tensor``.

        Args:
            tensor: Data for the new image.
            affine: Affine for the new image; when ``None`` the affine
                of the current image is reused.

        Returns:
            A new ``type(self)`` instance built with this image's
            ``metadata``, ``check_nans`` and ``reader`` values.
        """
        image_class = type(self)
        if affine is None:
            affine = self.affine
        init_kwargs = {
            'tensor': tensor,
            'affine': affine,
            'metadata': self.metadata,
            'check_nans': self.check_nans,
            'reader': self.reader,
        }
        return image_class(**init_kwargs)
|
40
|
|
|
|
|
41
|
|
|
|
|
42
|
|
|
class TestCustomImageSubclass:
    """Test suite for custom Image subclasses with transforms.

    Each test checks that an image subclass with extra constructor
    arguments keeps both its concrete type and its custom attribute
    after being processed, relying on the subclass's ``new_like``.
    """

    @pytest.fixture
    def history_image(self):
        """Create a 1x10x10x10 HistoryImage with history ``['created']``."""
        tensor = torch.rand(1, 10, 10, 10)
        affine = torch.eye(4)
        return HistoryImage(tensor=tensor, affine=affine, history=['created'])

    @pytest.fixture
    def metadata_image(self):
        """Create a 1x12x12x12 MetadataImage with a two-key metadata dict."""
        tensor = torch.rand(1, 12, 12, 12)
        affine = torch.eye(4)
        return MetadataImage(
            tensor=tensor, affine=affine, metadata={'id': 123, 'source': 'test'}
        )

    @pytest.fixture
    def history_subject(self, history_image):
        """Create a Subject with HistoryImage stored under ``image``."""
        return tio.Subject(image=history_image)

    @pytest.fixture
    def metadata_subject(self, metadata_image):
        """Create a Subject with MetadataImage stored under ``image``."""
        return tio.Subject(image=metadata_image)

    def test_crop_with_history_image(self, history_subject):
        """Test that Crop works with a custom Image requiring ``history``."""
        transform = tio.Crop(cropping=2)
        result = transform(history_subject)

        # Check that the result is still a HistoryImage
        assert isinstance(result.image, HistoryImage)

        # Check that custom attribute is preserved
        assert result.image.history == ['created']

        # Check that cropping worked correctly (10 -> 6 along each axis)
        assert result.image.shape == (1, 6, 6, 6)

    def test_crop_with_metadata_image(self, metadata_subject):
        """Test that Crop works with a custom Image with optional parameters."""
        transform = tio.Crop(cropping=1)
        result = transform(metadata_subject)

        # Check that the result is still a MetadataImage
        assert isinstance(result.image, MetadataImage)

        # Check that custom attribute is preserved
        assert result.image.metadata == {'id': 123, 'source': 'test'}

        # Check that cropping worked correctly (12 -> 10 along each axis)
        assert result.image.shape == (1, 10, 10, 10)

    def test_chained_transforms_preserve_attributes(self, history_subject):
        """Test that chained transforms preserve custom attributes."""
        # Chain multiple transforms
        transform = tio.Compose(
            [
                tio.Crop(cropping=1),
                tio.Crop(cropping=1),
            ]
        )

        result = transform(history_subject)

        # Check that the result is still a HistoryImage after multiple transforms
        assert isinstance(result.image, HistoryImage)

        # Check that custom attribute is preserved through the chain
        assert result.image.history == ['created']

        # Check that both crops were applied (10 -> 8 -> 6 along each axis)
        assert result.image.shape == (1, 6, 6, 6)

    def test_backward_compatibility_standard_images(self):
        """Test that standard Images still work with transforms."""
        # Create a standard ScalarImage
        tensor = torch.rand(1, 10, 10, 10)
        affine = torch.eye(4)
        image = tio.ScalarImage(tensor=tensor, affine=affine)
        subject = tio.Subject(image=image)

        # Apply transform
        transform = tio.Crop(cropping=2)
        result = transform(subject)

        # Check that it still works
        assert isinstance(result.image, tio.ScalarImage)
        assert result.image.shape == (1, 6, 6, 6)

    def test_to_reference_space_with_custom_image(self, history_image):
        """Test that ToReferenceSpace works with custom images.

        The custom ``history_image`` itself acts as the reference, so the
        returned image is expected to be a ``HistoryImage`` carrying the
        same ``history``.
        """
        # Create embedding tensor (same spatial shape as the fixture image)
        embedding_tensor = torch.rand(1, 10, 10, 10)

        # Use ToReferenceSpace.from_tensor with the custom image as reference
        result = tio.ToReferenceSpace.from_tensor(embedding_tensor, history_image)

        # Check that the result preserves the custom class type
        assert isinstance(result, HistoryImage)

        # Check that custom attribute is preserved
        assert result.history == ['created']

    def test_new_like_method_directly(self, history_image):
        """Test the new_like method directly with an explicit affine."""
        new_tensor = torch.rand(1, 5, 5, 5)
        new_affine = torch.eye(4) * 2

        # Create new image using new_like
        new_image = history_image.new_like(tensor=new_tensor, affine=new_affine)

        # Check type preservation
        assert isinstance(new_image, HistoryImage)

        # Check attribute preservation
        assert new_image.history == ['created']

        # Check new data
        assert torch.equal(new_image.data, new_tensor)
        # Cast both sides to float32 so the comparison does not depend on
        # the dtype in which the image stores its affine.
        assert torch.allclose(
            torch.tensor(new_image.affine).float(), new_affine.float()
        )

    def test_new_like_with_default_affine(self, metadata_image):
        """Test new_like method with default affine (None)."""
        new_tensor = torch.rand(1, 8, 8, 8)

        # Create new image using new_like with default affine
        new_image = metadata_image.new_like(tensor=new_tensor)

        # Check that original affine is used
        assert torch.allclose(
            torch.tensor(new_image.affine), torch.tensor(metadata_image.affine)
        )

        # Check attribute preservation
        assert new_image.metadata == {'id': 123, 'source': 'test'}

    def test_label_map_subclass(self):
        """Test that custom LabelMap subclasses also work."""

        class CustomLabelMap(tio.LabelMap):
            """Label map whose constructor requires ``labels_info``."""

            def __init__(self, tensor, affine, labels_info, **kwargs):
                super().__init__(tensor=tensor, affine=affine, **kwargs)
                self.labels_info = labels_info

            def new_like(self, tensor, affine=None):
                """Rebuild this class around ``tensor``, keeping ``labels_info``."""
                return type(self)(
                    tensor=tensor,
                    affine=affine if affine is not None else self.affine,
                    labels_info=self.labels_info,
                    check_nans=self.check_nans,
                    reader=self.reader,
                )

        # Create custom label map
        tensor = torch.randint(0, 3, (1, 8, 8, 8))
        affine = torch.eye(4)
        labels_info = {0: 'background', 1: 'tissue1', 2: 'tissue2'}

        custom_label = CustomLabelMap(
            tensor=tensor, affine=affine, labels_info=labels_info
        )
        subject = tio.Subject(labels=custom_label)

        # Apply transform
        transform = tio.Crop(cropping=1)
        result = transform(subject)

        # Check preservation (8 -> 6 along each axis)
        assert isinstance(result.labels, CustomLabelMap)
        assert result.labels.labels_info == labels_info
        assert result.labels.shape == (1, 6, 6, 6)
|
225
|
|
|
|