from __future__ import annotations

import copy
import pprint
from collections.abc import Sequence
from typing import TYPE_CHECKING
from typing import Any

import numpy as np

from ..constants import INTENSITY
from ..constants import TYPE
from ..utils import get_subclasses
from .image import Image

if TYPE_CHECKING:
    from ..transforms import Compose
    from ..transforms import Transform


class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Args:
        *args: If provided, a dictionary of items.
        **kwargs: Items that will be added to the subject sample.

    Example:

        >>> import torchio as tio
        >>> # One way:
        >>> subject = tio.Subject(
        ...     one_image=tio.ScalarImage('path_to_image.nii.gz'),
        ...     a_segmentation=tio.LabelMap('path_to_seg.nii.gz'),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping beforehand, or the keys contain spaces:
        >>> subject_dict = {
        ...     'one image': tio.ScalarImage('path_to_image.nii.gz'),
        ...     'a segmentation': tio.LabelMap('path_to_seg.nii.gz'),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> subject = tio.Subject(subject_dict)
    """

    def __init__(self, *args, **kwargs: Any):
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                kwargs.update(args[0])
            else:
                message = 'Only one dictionary as positional argument is allowed'
                raise ValueError(message)
        super().__init__(**kwargs)
        self._parse_images(self.get_images(intensity_only=False))
        self.update_attributes()  # this allows attribute access, e.g. subject.t1
        self.applied_transforms: list[tuple[str, dict]] = []

    def __repr__(self):
        num_images = len(self.get_images(intensity_only=False))
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {num_images})'
        )
        return string

    def __len__(self):
        return len(self.get_images(intensity_only=False))

    def __getitem__(self, item):
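        """Return a deep copy of the subject with every image indexed by ``item``.

        String keys are looked up as in a regular dictionary.

        Example:

            A minimal usage sketch; the dataset and indices are just examples:

            >>> import torchio as tio
            >>> subject = tio.datasets.Colin27()
            >>> cropped = subject[:, 60:120, 60:120, :]
        """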
        if isinstance(item, (slice, int, tuple)):
            try:
                self.check_consistent_spatial_shape()
            except RuntimeError as e:
                message = (
                    'To use indexing, all images in the subject must have the'
                    ' same spatial shape'
                )
                raise RuntimeError(message) from e
            copied = copy.deepcopy(self)
            # Index only the images; non-image values (e.g. age) are kept as-is
            images_dict = copied.get_images_dict(intensity_only=False)
            for image_name, image in images_dict.items():
                copied[image_name] = image[item]
            return copied
        else:
            return super().__getitem__(item)

    @staticmethod
    def _parse_images(images: list[Image]) -> None:
        # Check that it's not empty
        if not images:
            raise TypeError('A subject without images cannot be created')

    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.

        Example:

            >>> import torchio as tio
            >>> colin = tio.datasets.Colin27()
            >>> colin.shape
            (1, 181, 217, 181)
        """
        self.check_consistent_attribute('shape')
        return self.get_first_image().shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        Consistency of spatial shapes across images in the subject is checked first.

        Example:

            >>> import torchio as tio
            >>> colin = tio.datasets.Colin27()
            >>> colin.spatial_shape
            (181, 217, 181)
        """
        self.check_consistent_spatial_shape()
        return self.get_first_image().spatial_shape

    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of spacings across images in the subject is checked first.

        Example:

            >>> import torchio as tio
            >>> subject = tio.datasets.Slicer()
            >>> subject.spacing
            (1.0, 1.0, 1.2999954223632812)
        """
        self.check_consistent_attribute('spacing')
        return self.get_first_image().spacing

    @property
    def history(self):
        # Kept for backwards compatibility
        return self.get_applied_transforms()

    def is_2d(self):
        return all(i.is_2d() for i in self.get_images(intensity_only=False))

    def get_applied_transforms(
        self,
        ignore_intensity: bool = False,
        image_interpolation: str | None = None,
    ) -> list[Transform]:
        from ..transforms.intensity_transform import IntensityTransform
        from ..transforms.transform import Transform

        name_to_transform = {cls.__name__: cls for cls in get_subclasses(Transform)}
        transforms_list = []
        for transform_name, arguments in self.applied_transforms:
            transform = name_to_transform[transform_name](**arguments)
            if ignore_intensity and isinstance(transform, IntensityTransform):
                continue
            resamples = hasattr(transform, 'image_interpolation')
            if resamples and image_interpolation is not None:
                parsed = transform.parse_interpolation(image_interpolation)
                transform.image_interpolation = parsed
            transforms_list.append(transform)
        return transforms_list

    def get_composed_history(
        self,
        ignore_intensity: bool = False,
        image_interpolation: str | None = None,
    ) -> Compose:
        from ..transforms.augmentation.composition import Compose

        transforms = self.get_applied_transforms(
            ignore_intensity=ignore_intensity,
            image_interpolation=image_interpolation,
        )
        return Compose(transforms)

    def get_inverse_transform(
        self,
        warn: bool = True,
        ignore_intensity: bool = False,
        image_interpolation: str | None = None,
    ) -> Compose:
        """Get a transform that composes the inverses of the applied transforms.

        The inverses are composed in reverse order of application.

        Args:
            warn: Issue a warning if some transforms are not invertible.
            ignore_intensity: If ``True``, all instances of
                :class:`~torchio.transforms.intensity_transform.IntensityTransform`
                will be ignored.
            image_interpolation: Modify interpolation for scalar images inside
                transforms that perform resampling.
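
        Example:

            A minimal sketch; the transform and dataset are just examples:

            >>> import torchio as tio
            >>> subject = tio.datasets.Colin27()
            >>> transformed = tio.RandomAffine()(subject)
            >>> inverse = transformed.get_inverse_transform(warn=False)
            >>> back = inverse(transformed)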
        """
        history_transform = self.get_composed_history(
            ignore_intensity=ignore_intensity,
            image_interpolation=image_interpolation,
        )
        inverse_transform = history_transform.inverse(warn=warn)
        return inverse_transform

    def apply_inverse_transform(self, **kwargs) -> Subject:
        """Apply the inverse of all applied transforms, in reverse order.

        Args:
            **kwargs: Keyword arguments passed on to
                :meth:`~torchio.data.subject.Subject.get_inverse_transform`.
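
        Example:

            A minimal sketch; the transform and dataset are just examples:

            >>> import torchio as tio
            >>> subject = tio.datasets.Colin27()
            >>> transformed = tio.RandomFlip()(subject)
            >>> restored = transformed.apply_inverse_transform(warn=False)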
        """
        inverse_transform = self.get_inverse_transform(**kwargs)
        transformed: Subject
        transformed = inverse_transform(self)  # type: ignore[assignment]
        transformed.clear_history()
        return transformed

    def clear_history(self) -> None:
        self.applied_transforms = []

    def check_consistent_attribute(
        self,
        attribute: str,
        relative_tolerance: float = 1e-6,
        absolute_tolerance: float = 1e-6,
        message: str | None = None,
    ) -> None:
        r"""Check for consistency of an attribute across all images.

        Args:
            attribute: Name of the image attribute to check.
            relative_tolerance: Relative tolerance for :func:`numpy.allclose()`.
            absolute_tolerance: Absolute tolerance for :func:`numpy.allclose()`.
            message: Optional error message template; ``'{}'`` is replaced with
                the offending values.

        Example:
            >>> import numpy as np
            >>> import torch
            >>> import torchio as tio
            >>> scalars = torch.randn(1, 512, 512, 100)
            >>> mask = (scalars > 0).type(torch.int16)
            >>> af1 = np.diag([0.8, 0.8, 2.50000000000001, 1])
            >>> af2 = np.diag([0.8, 0.8, 2.49999999999999, 1])  # small difference here (e.g. due to different reader)
            >>> subject = tio.Subject(
            ...     image=tio.ScalarImage(tensor=scalars, affine=af1),
            ...     mask=tio.LabelMap(tensor=mask, affine=af2),
            ... )
            >>> subject.check_consistent_attribute('spacing')  # no error as tolerances are > 0

        .. note:: To check that all values for a specific attribute are close
            between all images in the subject, :func:`numpy.allclose()` is used.
            This function returns ``True`` if
            :math:`|a_i - b_i| \leq t_{abs} + t_{rel} * |b_i|`, where
            :math:`a_i` and :math:`b_i` are the :math:`i`-th element of the same
            attribute of two images being compared,
            :math:`t_{abs}` is the ``absolute_tolerance`` and
            :math:`t_{rel}` is the ``relative_tolerance``.
        """
        if message is None:
            message = (
                f'More than one value for "{attribute}" found in subject images:\n{{}}'
            )

        names_images = self.get_images_dict(intensity_only=False).items()
        try:
            first_attribute = None
            first_image = None

            for image_name, image in names_images:
                if first_attribute is None:
                    first_attribute = getattr(image, attribute)
                    first_image = image_name
                    continue
                current_attribute = getattr(image, attribute)
                all_close = np.allclose(
                    current_attribute,
                    first_attribute,
                    rtol=relative_tolerance,
                    atol=absolute_tolerance,
                )
                if not all_close:
                    message = message.format(
                        pprint.pformat(
                            {
                                first_image: first_attribute,
                                image_name: current_attribute,
                            }
                        ),
                    )
                    raise RuntimeError(message)
        except TypeError:
            # fallback for non-numeric values
            values_dict = {}
            for image_name, image in names_images:
                values_dict[image_name] = getattr(image, attribute)
            num_unique_values = len(set(values_dict.values()))
            if num_unique_values > 1:
                message = message.format(pprint.pformat(values_dict))
                raise RuntimeError(message) from None

    def check_consistent_spatial_shape(self) -> None:
        self.check_consistent_attribute('spatial_shape')

    def check_consistent_orientation(self) -> None:
        self.check_consistent_attribute('orientation')

    def check_consistent_affine(self) -> None:
        self.check_consistent_attribute('affine')

    def check_consistent_space(self) -> None:
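        """Check that all images share spacing, direction, origin and shape.

        Example:

            A minimal sketch; the dataset is just an example:

            >>> import torchio as tio
            >>> subject = tio.datasets.Colin27()
            >>> subject.check_consistent_space()  # raises RuntimeError if spaces differ
        """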
        try:
            self.check_consistent_attribute('spacing')
            self.check_consistent_attribute('direction')
            self.check_consistent_attribute('origin')
            self.check_consistent_spatial_shape()
        except RuntimeError as e:
            message = (
                'As described above, some images in the subject are not in the'
                ' same space. You can probably use the transforms ToCanonical'
                ' and Resample to fix this, as explained at'
                ' https://github.com/TorchIO-project/torchio/issues/647#issuecomment-913025695'
            )
            raise RuntimeError(message) from e

    def get_images_names(self) -> list[str]:
        return list(self.get_images_dict(intensity_only=False).keys())

    def get_images_dict(
        self,
        intensity_only=True,
        include: Sequence[str] | None = None,
        exclude: Sequence[str] | None = None,
    ) -> dict[str, Image]:
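        """Return a name-to-image dictionary, optionally filtered by type or name.

        Example:

            A minimal usage sketch; the image names are examples from this dataset:

            >>> import torchio as tio
            >>> subject = tio.datasets.Colin27()
            >>> all_images = subject.get_images_dict(intensity_only=False)
            >>> t1_only = subject.get_images_dict(intensity_only=False, include=['t1'])
        """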
        images = {}
        for image_name, image in self.items():
            if not isinstance(image, Image):
                continue
            if intensity_only and image[TYPE] != INTENSITY:
                continue
            if include is not None and image_name not in include:
                continue
            if exclude is not None and image_name in exclude:
                continue
            images[image_name] = image
        return images

    def get_images(
        self,
        intensity_only=True,
        include: Sequence[str] | None = None,
        exclude: Sequence[str] | None = None,
    ) -> list[Image]:
        images_dict = self.get_images_dict(
            intensity_only=intensity_only,
            include=include,
            exclude=exclude,
        )
        return list(images_dict.values())

    def get_image(self, image_name: str) -> Image:
        """Get a single image by its name."""
        return self.get_images_dict(intensity_only=False)[image_name]

    def get_first_image(self) -> Image:
        return self.get_images(intensity_only=False)[0]

    def add_transform(
        self,
        transform: Transform,
        parameters_dict: dict,
    ) -> None:
        self.applied_transforms.append((transform.name, parameters_dict))

    def load(self) -> None:
        """Load all images in the subject into memory."""
        for image in self.get_images(intensity_only=False):
            image.load()

    def unload(self) -> None:
        """Unload all images in the subject from memory."""
        for image in self.get_images(intensity_only=False):
            image.unload()

    def update_attributes(self) -> None:
        # This allows getting images using attribute notation, e.g. subject.t1
        self.__dict__.update(self)

    @staticmethod
    def _check_image_name(image_name):
        if not isinstance(image_name, str):
            message = (
                f'The image name must be a string, but it has type "{type(image_name)}"'
            )
            raise ValueError(message)
        return image_name

    def add_image(self, image: Image, image_name: str) -> None:
        """Add an image to the subject instance.
        """
        if not isinstance(image, Image):
            message = (
                'Image must be an instance of torchio.Image,'
                f' but its type is "{type(image)}"'
            )
            raise ValueError(message)
        self._check_image_name(image_name)
        self[image_name] = image
        self.update_attributes()

    def remove_image(self, image_name: str) -> None:
        """Remove an image from the subject instance.
        """
        self._check_image_name(image_name)
        del self[image_name]
        delattr(self, image_name)

    def plot(self, **kwargs) -> None:
        """Plot images using matplotlib.

        Args:
            **kwargs: Keyword arguments that will be passed on to
                :meth:`~torchio.Image.plot`.
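
        Example:

            A minimal sketch; the dataset is just an example:

            >>> import torchio as tio
            >>> subject = tio.datasets.Colin27()
            >>> subject.plot()  # doctest: +SKIP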
        """
        from ..visualization import plot_subject  # avoid circular import

        plot_subject(self, **kwargs)