from __future__ import annotations

import copy
import pprint
from collections.abc import Mapping
from collections.abc import Sequence
from typing import TYPE_CHECKING
from typing import Any

import numpy as np

from ..constants import INTENSITY
from ..constants import TYPE
from ..utils import get_subclasses
from .image import Image

if TYPE_CHECKING:
    from ..transforms import Compose
    from ..transforms import Transform


class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Values that are :class:`~torchio.Image` instances are treated as images;
    any other values are stored as plain metadata. Every key is also exposed
    as an instance attribute (e.g. ``subject.t1``).

    Args:
        *args: If provided, a single mapping (e.g. a dictionary) of items.
            Its entries take precedence over keyword arguments that use the
            same key.
        **kwargs: Items that will be added to the subject sample.

    Raises:
        ValueError: If more than one positional argument is given, or if the
            positional argument is not a mapping.
        TypeError: If the subject would not contain any image.

    Example:

        >>> import torchio as tio
        >>> # One way:
        >>> subject = tio.Subject(
        ...     one_image=tio.ScalarImage('path_to_image.nii.gz'),
        ...     a_segmentation=tio.LabelMap('path_to_seg.nii.gz'),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping before, or have spaces in the keys:
        >>> subject_dict = {
        ...     'one image': tio.ScalarImage('path_to_image.nii.gz'),
        ...     'a segmentation': tio.LabelMap('path_to_seg.nii.gz'),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> subject = tio.Subject(subject_dict)
    """

    def __init__(self, *args: Mapping[str, Any], **kwargs: Any):
        if args:
            if len(args) == 1 and isinstance(args[0], Mapping):
                # Positional entries override keyword entries with equal keys
                kwargs.update(args[0])
            else:
                message = 'Only one dictionary as positional argument is allowed'
                raise ValueError(message)
        super().__init__(**kwargs)
        self._parse_images(self.get_images(intensity_only=False))
        self.update_attributes()  # this allows me to do e.g. subject.t1
        # History of (transform name, arguments) pairs, see add_transform()
        self.applied_transforms: list[tuple[str, dict]] = []

    def __repr__(self):
        num_images = len(self.get_images(intensity_only=False))
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {num_images})'
        )
        return string

    def __len__(self):
        # Counts images only, not every key of the underlying dictionary
        return len(self.get_images(intensity_only=False))

    def __getitem__(self, item):
        """Return a value by key, or a spatially indexed copy of the subject.

        If ``item`` is a ``slice``, ``int`` or ``tuple``, a deep copy of the
        subject is returned in which every image has been indexed with
        ``item``; non-image values are kept unchanged. Any other ``item`` is
        used as a regular dictionary key.

        Raises:
            RuntimeError: If spatial indexing is requested but the images do
                not share the same spatial shape.
        """
        if not isinstance(item, (slice, int, tuple)):
            return super().__getitem__(item)
        try:
            self.check_consistent_spatial_shape()
        except RuntimeError as e:
            message = (
                'To use indexing, all images in the subject must have the'
                ' same spatial shape'
            )
            raise RuntimeError(message) from e
        copied = copy.deepcopy(self)
        # Only images support spatial indexing; metadata such as ints or
        # strings must not be indexed (an int would raise, a string would be
        # silently sliced)
        images = copied.get_images_dict(intensity_only=False)
        for image_name, image in images.items():
            copied[image_name] = image[item]
        # Keep attribute access (e.g. copied.t1) in sync with the new images
        copied.update_attributes()
        return copied

    @staticmethod
    def _parse_images(images: list[Image]) -> None:
        """Raise ``TypeError`` if ``images`` is empty."""
        # Check that it's not empty
        if not images:
            raise TypeError('A subject without images cannot be created')

    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.

        Example:

            >>> import torchio as tio
            >>> colin = tio.datasets.Colin27()
            >>> colin.shape
            (1, 181, 217, 181)
        """
        self.check_consistent_attribute('shape')
        return self.get_first_image().shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        Consistency of spatial shapes across images in the subject is checked
        first.

        Example:

            >>> import torchio as tio
            >>> colin = tio.datasets.Colin27()
            >>> colin.spatial_shape
            (181, 217, 181)
        """
        self.check_consistent_spatial_shape()
        return self.get_first_image().spatial_shape

    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of spacings across images in the subject is checked first.

        Example:

            >>> import torchio as tio
            >>> slicer = tio.datasets.Slicer()
            >>> slicer.spacing
            (1.0, 1.0, 1.2999954223632812)
        """
        self.check_consistent_attribute('spacing')
        return self.get_first_image().spacing

    @property
    def history(self):
        """Return :meth:`get_applied_transforms` with default arguments."""
        # Kept for backwards compatibility
        return self.get_applied_transforms()

    def is_2d(self):
        """Return ``True`` if every image in the subject reports being 2D."""
        return all(i.is_2d() for i in self.get_images(intensity_only=False))

    def get_applied_transforms(
        self,
        ignore_intensity: bool = False,
        image_interpolation: str | None = None,
    ) -> list[Transform]:
        """Re-instantiate the transforms recorded in ``applied_transforms``.

        Args:
            ignore_intensity: If ``True``, instances of
                :class:`~torchio.transforms.intensity_transform.IntensityTransform`
                are left out.
            image_interpolation: If not ``None``, overrides the
                ``image_interpolation`` attribute of the transforms that
                have one.

        Returns:
            The transform instances, in the order they were recorded.
        """
        from ..transforms.intensity_transform import IntensityTransform
        from ..transforms.transform import Transform

        name_to_transform = {cls.__name__: cls for cls in get_subclasses(Transform)}
        transforms_list = []
        for transform_name, arguments in self.applied_transforms:
            transform = name_to_transform[transform_name](**arguments)
            if ignore_intensity and isinstance(transform, IntensityTransform):
                continue
            resamples = hasattr(transform, 'image_interpolation')
            if resamples and image_interpolation is not None:
                parsed = transform.parse_interpolation(image_interpolation)
                transform.image_interpolation = parsed
            transforms_list.append(transform)
        return transforms_list

    def get_composed_history(
        self,
        ignore_intensity: bool = False,
        image_interpolation: str | None = None,
    ) -> Compose:
        """Return the applied transforms wrapped in a ``Compose``.

        Arguments are passed on to :meth:`get_applied_transforms`.
        """
        from ..transforms.augmentation.composition import Compose

        transforms = self.get_applied_transforms(
            ignore_intensity=ignore_intensity,
            image_interpolation=image_interpolation,
        )
        return Compose(transforms)

    def get_inverse_transform(
        self,
        warn: bool = True,
        ignore_intensity: bool = False,
        image_interpolation: str | None = None,
    ) -> Compose:
        """Get a reversed list of the inverses of the applied transforms.

        Args:
            warn: Issue a warning if some transforms are not invertible.
            ignore_intensity: If ``True``, all instances of
                :class:`~torchio.transforms.intensity_transform.IntensityTransform`
                will be ignored.
            image_interpolation: Modify interpolation for scalar images inside
                transforms that perform resampling.
        """
        history_transform = self.get_composed_history(
            ignore_intensity=ignore_intensity,
            image_interpolation=image_interpolation,
        )
        inverse_transform = history_transform.inverse(warn=warn)
        return inverse_transform

    def apply_inverse_transform(self, **kwargs) -> Subject:
        """Apply the inverse of all applied transforms, in reverse order.

        The history of the returned subject is cleared.

        Args:
            **kwargs: Keyword arguments passed on to
                :meth:`~torchio.data.subject.Subject.get_inverse_transform`.
        """
        inverse_transform = self.get_inverse_transform(**kwargs)
        transformed: Subject
        transformed = inverse_transform(self)  # type: ignore[assignment]
        transformed.clear_history()
        return transformed

    def clear_history(self) -> None:
        """Forget all recorded transforms."""
        self.applied_transforms = []

    def check_consistent_attribute(
        self,
        attribute: str,
        relative_tolerance: float = 1e-6,
        absolute_tolerance: float = 1e-6,
        message: str | None = None,
    ) -> None:
        r"""Check for consistency of an attribute across all images.

        Args:
            attribute: Name of the image attribute to check
            relative_tolerance: Relative tolerance for :func:`numpy.allclose()`
            absolute_tolerance: Absolute tolerance for :func:`numpy.allclose()`
            message: Optional header for the error message. If ``None``, a
                default header naming ``attribute`` is used. The conflicting
                values are always appended on the following lines.

        Raises:
            RuntimeError: If the attribute differs between images.

        Example:

            >>> import numpy as np
            >>> import torch
            >>> import torchio as tio
            >>> scalars = torch.randn(1, 512, 512, 100)
            >>> mask = (scalars > 0).type(torch.int16)
            >>> af1 = np.diag([0.8, 0.8, 2.50000000000001, 1])
            >>> af2 = np.diag([0.8, 0.8, 2.49999999999999, 1])  # small difference here (e.g. due to different reader)
            >>> subject = tio.Subject(
            ...     image=tio.ScalarImage(tensor=scalars, affine=af1),
            ...     mask=tio.LabelMap(tensor=mask, affine=af2),
            ... )
            >>> subject.check_consistent_attribute('spacing')  # no error as tolerances are > 0

        .. note:: To check that all values for a specific attribute are close
            between all images in the subject, :func:`numpy.allclose()` is used.
            This function returns ``True`` if
            :math:`|a_i - b_i| \leq t_{abs} + t_{rel} * |b_i|`, where
            :math:`a_i` and :math:`b_i` are the :math:`i`-th element of the same
            attribute of two images being compared,
            :math:`t_{abs}` is the ``absolute_tolerance`` and
            :math:`t_{rel}` is the ``relative_tolerance``.
            Values that :func:`numpy.allclose()` cannot compare (e.g. strings)
            are compared for exact equality instead.
        """
        if message is None:
            message = f'More than one value for "{attribute}" found in subject images:'

        names_images = list(self.get_images_dict(intensity_only=False).items())
        if not names_images:
            return

        # The first image is the reference. It is taken explicitly rather than
        # through a "None" sentinel, so a legitimate None value on the first
        # image does not skip the comparison
        first_name, first_image = names_images[0]
        try:
            first_value = getattr(first_image, attribute)
            for image_name, image in names_images[1:]:
                current_value = getattr(image, attribute)
                all_close = np.allclose(
                    current_value,
                    first_value,
                    rtol=relative_tolerance,
                    atol=absolute_tolerance,
                )
                if not all_close:
                    conflicting = {
                        first_name: first_value,
                        image_name: current_value,
                    }
                    raise RuntimeError(f'{message}\n{pprint.pformat(conflicting)}')
        except TypeError:
            # Fallback for non-numeric values, which np.allclose rejects
            values_dict = {
                image_name: getattr(image, attribute)
                for image_name, image in names_images
            }
            num_unique_values = len(set(values_dict.values()))
            if num_unique_values > 1:
                details = pprint.pformat(values_dict)
                raise RuntimeError(f'{message}\n{details}') from None

    def check_consistent_spatial_shape(self) -> None:
        """Raise ``RuntimeError`` if images differ in spatial shape."""
        self.check_consistent_attribute('spatial_shape')

    def check_consistent_orientation(self) -> None:
        """Raise ``RuntimeError`` if images differ in orientation."""
        self.check_consistent_attribute('orientation')

    def check_consistent_affine(self) -> None:
        """Raise ``RuntimeError`` if images differ in affine matrix."""
        self.check_consistent_attribute('affine')

    def check_consistent_space(self) -> None:
        """Check that spacing, direction, origin and spatial shape all match.

        Raises:
            RuntimeError: If any of these attributes differ between images.
                The specific inconsistency is chained as the cause.
        """
        try:
            self.check_consistent_attribute('spacing')
            self.check_consistent_attribute('direction')
            self.check_consistent_attribute('origin')
            self.check_consistent_spatial_shape()
        except RuntimeError as e:
            message = (
                'As described above, some images in the subject are not in the'
                ' same space. You probably can use the transforms ToCanonical'
                ' and Resample to fix this, as explained at'
                ' https://github.com/TorchIO-project/torchio/issues/647#issuecomment-913025695'
            )
            raise RuntimeError(message) from e

    def get_images_names(self) -> list[str]:
        """Return the keys of all images, in insertion order."""
        return list(self.get_images_dict(intensity_only=False).keys())

    def get_images_dict(
        self,
        intensity_only: bool = True,
        include: Sequence[str] | None = None,
        exclude: Sequence[str] | None = None,
    ) -> dict[str, Image]:
        """Return a dictionary mapping names to images.

        Args:
            intensity_only: If ``True``, keep only images whose type is
                ``INTENSITY``.
            include: If given, keep only images whose name is in it.
            exclude: If given, drop images whose name is in it.
        """
        images = {}
        for image_name, image in self.items():
            if not isinstance(image, Image):
                continue
            if intensity_only and image[TYPE] != INTENSITY:
                continue
            if include is not None and image_name not in include:
                continue
            if exclude is not None and image_name in exclude:
                continue
            images[image_name] = image
        return images

    def get_images(
        self,
        intensity_only: bool = True,
        include: Sequence[str] | None = None,
        exclude: Sequence[str] | None = None,
    ) -> list[Image]:
        """Return a list of images, filtered as in :meth:`get_images_dict`."""
        images_dict = self.get_images_dict(
            intensity_only=intensity_only,
            include=include,
            exclude=exclude,
        )
        return list(images_dict.values())

    def get_first_image(self) -> Image:
        """Return the first image, in insertion order."""
        return self.get_images(intensity_only=False)[0]

    def add_transform(
        self,
        transform: Transform,
        parameters_dict: dict,
    ) -> None:
        """Record that ``transform`` was applied with ``parameters_dict``."""
        self.applied_transforms.append((transform.name, parameters_dict))

    def load(self) -> None:
        """Load images in subject on RAM."""
        for image in self.get_images(intensity_only=False):
            image.load()

    def unload(self) -> None:
        """Unload images in subject."""
        for image in self.get_images(intensity_only=False):
            image.unload()

    def update_attributes(self) -> None:
        """Mirror every key-value pair as an instance attribute.

        Attributes of keys that have been removed are not deleted here.
        """
        # This allows to get images using attribute notation, e.g. subject.t1
        self.__dict__.update(self)

    @staticmethod
    def _check_image_name(image_name):
        """Return ``image_name``, raising ``ValueError`` if not a string."""
        if not isinstance(image_name, str):
            message = (
                f'The image name must be a string, but it has type "{type(image_name)}"'
            )
            raise ValueError(message)
        return image_name

    def add_image(self, image: Image, image_name: str) -> None:
        """Add an image to the subject instance.

        Raises:
            ValueError: If ``image`` is not an Image or ``image_name`` is not
                a string.
        """
        if not isinstance(image, Image):
            message = (
                'Image must be an instance of torchio.Image,'
                f' but its type is "{type(image)}"'
            )
            raise ValueError(message)
        self._check_image_name(image_name)
        self[image_name] = image
        self.update_attributes()

    def remove_image(self, image_name: str) -> None:
        """Remove an image from the subject instance.

        Raises:
            ValueError: If ``image_name`` is not a string.
            KeyError: If there is no entry called ``image_name``.
        """
        self._check_image_name(image_name)
        del self[image_name]
        # The mirrored attribute may be missing (e.g. the entry was set with
        # item assignment only), so removal must not fail after the key is
        # already gone
        self.__dict__.pop(image_name, None)

    def plot(self, **kwargs) -> None:
        """Plot images using matplotlib.

        Args:
            **kwargs: Keyword arguments that will be passed on to
                :meth:`~torchio.Image.plot`.
        """
        from ..visualization import plot_subject  # avoid circular import

        plot_subject(self, **kwargs)