import copy
import pprint
from typing import Any, Dict, List, Tuple, Optional, Sequence, TYPE_CHECKING

import numpy as np

from ..constants import TYPE, INTENSITY
from .image import Image
from ..utils import get_subclasses

if TYPE_CHECKING:
    from ..transforms import Transform, Compose


class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Args:
        *args: If provided, a dictionary of items.
        **kwargs: Items that will be added to the subject sample.

    Example:

        >>> import torchio as tio
        >>> # One way:
        >>> subject = tio.Subject(
        ...     one_image=tio.ScalarImage('path_to_image.nii.gz'),
        ...     a_segmentation=tio.LabelMap('path_to_seg.nii.gz'),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # Another way, if you want to create the mapping beforehand,
        >>> # or if the keys contain spaces:
        >>> subject_dict = {
        ...     'one image': tio.ScalarImage('path_to_image.nii.gz'),
        ...     'a segmentation': tio.LabelMap('path_to_seg.nii.gz'),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> subject = tio.Subject(subject_dict)

    """

    def __init__(self, *args, **kwargs: Any):
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                kwargs.update(args[0])
            else:
                message = (
                    'Only one dictionary as a positional argument is allowed')
                raise ValueError(message)
        super().__init__(**kwargs)
        self._parse_images(self.get_images(intensity_only=False))
        self.update_attributes()  # allows attribute access, e.g. subject.t1
        self.applied_transforms = []

    def __repr__(self):
        num_images = len(self.get_images(intensity_only=False))
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {num_images})'
        )
        return string

    def __copy__(self):
        result_dict = {}
        for key, value in self.items():
            if isinstance(value, Image):
                value = copy.copy(value)  # shallow copy of images
            else:
                value = copy.deepcopy(value)  # deep copy of everything else
            result_dict[key] = value
        new = Subject(result_dict)
        new.applied_transforms = self.applied_transforms[:]
        return new

    def __len__(self):
        return len(self.get_images(intensity_only=False))

    @staticmethod
    def _parse_images(images: List[Image]) -> None:
        # Check that it's not empty
        if not images:
            raise ValueError('A subject without images cannot be created')

    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.

        Example::

            >>> import torchio as tio
            >>> colin = tio.datasets.Colin27()
            >>> colin.shape
            (1, 181, 217, 181)

        """
        self.check_consistent_attribute('shape')
        return self.get_first_image().shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        Consistency of spatial shapes across images in the subject is checked
        first.

        Example::

            >>> import torchio as tio
            >>> colin = tio.datasets.Colin27()
            >>> colin.spatial_shape
            (181, 217, 181)
        """
        self.check_consistent_spatial_shape()
        return self.get_first_image().spatial_shape

    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of spacings across images in the subject is checked first.

        Example::

            >>> import torchio as tio
            >>> subject = tio.datasets.Slicer()
            >>> subject.spacing
            (1.0, 1.0, 1.2999954223632812)
        """
        self.check_consistent_attribute('spacing')
        return self.get_first_image().spacing

    @property
    def history(self):
        # Kept for backwards compatibility
        return self.get_applied_transforms()

    def is_2d(self):
        return all(i.is_2d() for i in self.get_images(intensity_only=False))

    def get_applied_transforms(
        self,
        ignore_intensity: bool = False,
        image_interpolation: Optional[str] = None,
    ) -> List['Transform']:
        from ..transforms.transform import Transform
        from ..transforms.intensity_transform import IntensityTransform
        name_to_transform = {
            cls.__name__: cls
            for cls in get_subclasses(Transform)
        }
        transforms_list = []
        for transform_name, arguments in self.applied_transforms:
            transform = name_to_transform[transform_name](**arguments)
            if ignore_intensity and isinstance(transform, IntensityTransform):
                continue
            resamples = hasattr(transform, 'image_interpolation')
            if resamples and image_interpolation is not None:
                parsed = transform.parse_interpolation(image_interpolation)
                transform.image_interpolation = parsed
            transforms_list.append(transform)
        return transforms_list

    def get_composed_history(
        self,
        ignore_intensity: bool = False,
        image_interpolation: Optional[str] = None,
    ) -> 'Compose':
        from ..transforms.augmentation.composition import Compose
        transforms = self.get_applied_transforms(
            ignore_intensity=ignore_intensity,
            image_interpolation=image_interpolation,
        )
        return Compose(transforms)
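
    # Note: the reconstructed transforms carry the arguments recorded when
    # they were applied, so the composed history can reproduce the same
    # augmentation on another subject. A minimal sketch (`other_subject` is
    # a hypothetical second Subject):
    #
    #     transformed = tio.RandomAffine()(subject)
    #     reproduce = transformed.get_composed_history()
    #     other_transformed = reproduce(other_subject)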

    def get_inverse_transform(
        self,
        warn: bool = True,
        ignore_intensity: bool = True,
        image_interpolation: Optional[str] = None,
    ) -> 'Compose':
        """Get a reversed list of the inverses of the applied transforms.

        Args:
            warn: Issue a warning if some transforms are not invertible.
            ignore_intensity: If ``True``, all instances of
                :class:`~torchio.transforms.intensity_transform.IntensityTransform`
                will be ignored.
            image_interpolation: Modify interpolation for scalar images inside
                transforms that perform resampling.
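
        Example (a sketch, assuming ``subject`` is a :class:`Subject` to
        which an invertible transform such as
        :class:`~torchio.transforms.RandomAffine` has been applied)::

            >>> transformed = tio.RandomAffine()(subject)
            >>> inverse = transformed.get_inverse_transform(warn=False)
            >>> recovered = inverse(transformed)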
        """
        history_transform = self.get_composed_history(
            ignore_intensity=ignore_intensity,
            image_interpolation=image_interpolation,
        )
        inverse_transform = history_transform.inverse(warn=warn)
        return inverse_transform

    def apply_inverse_transform(self, **kwargs) -> 'Subject':
        """Try to apply the inverse of all applied transforms, in reverse order.

        Args:
            **kwargs: Keyword arguments passed on to
                :meth:`~torchio.data.subject.Subject.get_inverse_transform`.
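
        Example (a sketch; ``transformed`` is a subject returned by an
        invertible transform, as in
        :meth:`~torchio.data.subject.Subject.get_inverse_transform`)::

            >>> recovered = transformed.apply_inverse_transform(warn=False)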
        """
        inverse_transform = self.get_inverse_transform(**kwargs)
        transformed = inverse_transform(self)
        transformed.clear_history()
        return transformed

    def clear_history(self) -> None:
        self.applied_transforms = []

    def check_consistent_attribute(
        self,
        attribute: str,
        relative_tolerance: float = 1e-6,
        absolute_tolerance: float = 1e-6,
        message: Optional[str] = None,
    ) -> None:
        r"""Check for consistency of an attribute across all images.

        Args:
            attribute: Name of the image attribute to check.
            relative_tolerance: Relative tolerance for :func:`numpy.allclose()`.
            absolute_tolerance: Absolute tolerance for :func:`numpy.allclose()`.
            message: Optional error message template. It must contain a ``{}``
                placeholder for the offending values; if ``None``, a default
                message is used.

        Example:
            >>> import numpy as np
            >>> import torch
            >>> import torchio as tio
            >>> scalars = torch.randn(1, 512, 512, 100)
            >>> mask = (scalars > 0).type(torch.int16)
            >>> af1 = np.diag([0.8, 0.8, 2.50000000000001, 1])
            >>> af2 = np.diag([0.8, 0.8, 2.49999999999999, 1])  # small difference here (e.g. due to different reader)
            >>> subject = tio.Subject(
            ...     image=tio.ScalarImage(tensor=scalars, affine=af1),
            ...     mask=tio.LabelMap(tensor=mask, affine=af2),
            ... )
            >>> subject.check_consistent_attribute('spacing')  # no error as tolerances are > 0

        .. note:: To check that all values for a specific attribute are close
            between all images in the subject, :func:`numpy.allclose()` is used.
            This function returns ``True`` if
            :math:`|a_i - b_i| \leq t_{abs} + t_{rel} * |b_i|`, where
            :math:`a_i` and :math:`b_i` are the :math:`i`-th element of the same
            attribute of two images being compared,
            :math:`t_{abs}` is the ``absolute_tolerance`` and
            :math:`t_{rel}` is the ``relative_tolerance``.
        """
        if message is None:
            message = (
                f'More than one value for "{attribute}" found in subject'
                ' images:\n{}'
            )

        names_images = self.get_images_dict(intensity_only=False).items()
        try:
            first_attribute = None
            first_image = None

            for image_name, image in names_images:
                if first_attribute is None:
                    first_attribute = getattr(image, attribute)
                    first_image = image_name
                    continue
                current_attribute = getattr(image, attribute)
                all_close = np.allclose(
                    current_attribute,
                    first_attribute,
                    rtol=relative_tolerance,
                    atol=absolute_tolerance,
                )
                if not all_close:
                    message = message.format(
                        pprint.pformat({
                            first_image: first_attribute,
                            image_name: current_attribute,
                        }),
                    )
                    raise RuntimeError(message)
        except TypeError:
            # Fallback for non-numeric values, which np.allclose cannot compare
            values_dict = {}
            for image_name, image in names_images:
                values_dict[image_name] = getattr(image, attribute)
            num_unique_values = len(set(values_dict.values()))
            if num_unique_values > 1:
                message = message.format(pprint.pformat(values_dict))
                raise RuntimeError(message) from None

    def check_consistent_spatial_shape(self) -> None:
        self.check_consistent_attribute('spatial_shape')

    def check_consistent_orientation(self) -> None:
        self.check_consistent_attribute('orientation')

    def check_consistent_affine(self) -> None:
        self.check_consistent_attribute('affine')

    def check_consistent_space(self) -> None:
        try:
            self.check_consistent_attribute('spacing')
            self.check_consistent_attribute('direction')
            self.check_consistent_attribute('origin')
            self.check_consistent_spatial_shape()
        except RuntimeError as e:
            message = (
                'As described above, some images in the subject are not in the'
                ' same space. You can probably use the transforms ToCanonical'
                ' and Resample to fix this, as explained at'
                ' https://github.com/fepegar/torchio/issues/647#issuecomment-913025695'
            )
            raise RuntimeError(message) from e
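
    # A minimal sketch of the fix suggested in the error message above
    # (image names and file paths are hypothetical):
    #
    #     subject = tio.Subject(
    #         t1=tio.ScalarImage('t1.nii.gz'),
    #         seg=tio.LabelMap('seg.nii.gz'),
    #     )
    #     normalize_space = tio.Compose([
    #         tio.ToCanonical(),   # reorient all images to RAS+
    #         tio.Resample('t1'),  # resample all images to the space of 't1'
    #     ])
    #     subject = normalize_space(subject)
    #     subject.check_consistent_space()  # should pass now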

    def get_images_names(self) -> List[str]:
        return list(self.get_images_dict(intensity_only=False).keys())

    def get_images_dict(
        self,
        intensity_only=True,
        include: Optional[Sequence[str]] = None,
        exclude: Optional[Sequence[str]] = None,
    ) -> Dict[str, Image]:
        images = {}
        for image_name, image in self.items():
            if not isinstance(image, Image):
                continue
            if intensity_only and image[TYPE] != INTENSITY:
                continue
            if include is not None and image_name not in include:
                continue
            if exclude is not None and image_name in exclude:
                continue
            images[image_name] = image
        return images
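
    # Sketch of the filtering semantics (the image name 'seg' is
    # hypothetical):
    #
    #     subject.get_images_dict(intensity_only=False, exclude=['seg'])
    #
    # returns a name-to-Image mapping that keeps label maps (because
    # intensity_only is False) but drops the key 'seg'.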

    def get_images(
        self,
        intensity_only=True,
        include: Optional[Sequence[str]] = None,
        exclude: Optional[Sequence[str]] = None,
    ) -> List[Image]:
        images_dict = self.get_images_dict(
            intensity_only=intensity_only,
            include=include,
            exclude=exclude,
        )
        return list(images_dict.values())

    def get_first_image(self) -> Image:
        return self.get_images(intensity_only=False)[0]

    # flake8: noqa: F821
    def add_transform(
        self,
        transform: 'Transform',
        parameters_dict: dict,
    ) -> None:
        self.applied_transforms.append((transform.name, parameters_dict))

    def load(self) -> None:
        """Load all images in the subject into RAM."""
        for image in self.get_images(intensity_only=False):
            image.load()

    def update_attributes(self) -> None:
        # This allows images to be accessed using attribute notation,
        # e.g. subject.t1
        self.__dict__.update(self)

    def add_image(self, image: Image, image_name: str) -> None:
        """Add an image to the subject."""
        self[image_name] = image
        self.update_attributes()

    def remove_image(self, image_name: str) -> None:
        """Remove an image from the subject."""
        del self[image_name]
        delattr(self, image_name)

    def plot(self, **kwargs) -> None:
        """Plot images using matplotlib.

        Args:
            **kwargs: Keyword arguments that will be passed on to
                :meth:`~torchio.Image.plot`.
        """
        from ..visualization import plot_subject  # avoid circular import
        plot_subject(self, **kwargs)