1
|
|
|
import copy |
2
|
|
|
import pprint |
3
|
|
|
from typing import Any, Dict, List, Tuple, Optional, Sequence, TYPE_CHECKING |
4
|
|
|
|
5
|
|
|
import numpy as np |
6
|
|
|
|
7
|
|
|
from ..constants import TYPE, INTENSITY |
8
|
|
|
from .image import Image |
9
|
|
|
from ..utils import get_subclasses |
10
|
|
|
|
11
|
|
|
if TYPE_CHECKING: |
12
|
|
|
from ..transforms import Transform, Compose |
13
|
|
|
|
14
|
|
|
|
15
|
|
|
class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Args:
        *args: If provided, a dictionary of items.
        **kwargs: Items that will be added to the subject sample.

    Raises:
        ValueError: If positional arguments other than a single dictionary
            are given, or if the resulting subject contains no images.

    Example:

        >>> import torchio as tio
        >>> # One way:
        >>> subject = tio.Subject(
        ...     one_image=tio.ScalarImage('path_to_image.nii.gz'),
        ...     a_segmentation=tio.LabelMap('path_to_seg.nii.gz'),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping before, or have spaces in the keys:
        >>> subject_dict = {
        ...     'one image': tio.ScalarImage('path_to_image.nii.gz'),
        ...     'a segmentation': tio.LabelMap('path_to_seg.nii.gz'),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> subject = tio.Subject(subject_dict)

    """

    def __init__(self, *args, **kwargs: Any):
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                # Items from the positional dict override same-named kwargs
                kwargs.update(args[0])
            else:
                message = (
                    'Only one dictionary as positional argument is allowed')
                raise ValueError(message)
        super().__init__(**kwargs)
        self._parse_images(self.get_images(intensity_only=False))
        self.update_attributes()  # this allows me to do e.g. subject.t1
        self.applied_transforms = []

    def __repr__(self):
        num_images = len(self.get_images(intensity_only=False))
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {num_images})'
        )
        return string

    def __copy__(self):
        """Return a copy of this subject.

        Values that are :class:`Image` instances are copied with
        :func:`copy.copy`; every other value is copied with
        :func:`copy.deepcopy`. The list of applied transforms is copied, so
        the history of the new subject can grow independently.
        """
        result_dict = {}
        for key, value in self.items():
            if isinstance(value, Image):
                value = copy.copy(value)
            else:
                value = copy.deepcopy(value)
            result_dict[key] = value
        # Use the runtime class so that copying a subclass instance does not
        # silently downgrade it to a plain Subject
        new = type(self)(result_dict)
        new.applied_transforms = self.applied_transforms[:]
        return new

    def __len__(self):
        """Return the number of images in the subject (not the number of keys)."""
        return len(self.get_images(intensity_only=False))

    @staticmethod
    def _parse_images(images: List[Image]) -> None:
        """Validate the images of a subject.

        Args:
            images: Images found in the subject.

        Raises:
            ValueError: If ``images`` is empty.
        """
        # Check that it's not empty
        if not images:
            raise ValueError('A subject without images cannot be created')

    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.
        """
        self.check_consistent_attribute('shape')
        return self.get_first_image().shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        Consistency of spatial shapes across images in the subject is checked
        first.
        """
        self.check_consistent_spatial_shape()
        return self.get_first_image().spatial_shape

    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of spacings across images in the subject is checked first.
        """
        self.check_consistent_attribute('spacing')
        return self.get_first_image().spacing

    @property
    def history(self):
        """Alias of :meth:`get_applied_transforms` with default arguments."""
        # Kept for backwards compatibility
        return self.get_applied_transforms()

    def is_2d(self):
        """Return ``True`` if every image in the subject reports being 2D."""
        return all(i.is_2d() for i in self.get_images(intensity_only=False))

    def get_applied_transforms(
            self,
            ignore_intensity: bool = False,
            image_interpolation: Optional[str] = None,
    ) -> List['Transform']:
        """Instantiate the transforms recorded in ``applied_transforms``.

        Args:
            ignore_intensity: If ``True``, instances of
                :class:`~torchio.transforms.intensity_transform.IntensityTransform`
                are left out of the returned list.
            image_interpolation: If not ``None``, override the
                ``image_interpolation`` of every transform that has that
                attribute.

        Returns:
            New transform instances, in the order they were recorded.

        Raises:
            KeyError: If a recorded transform name does not match any
                subclass of ``Transform``.
        """
        from ..transforms.transform import Transform
        from ..transforms.intensity_transform import IntensityTransform
        name_to_transform = {
            cls.__name__: cls
            for cls in get_subclasses(Transform)
        }
        transforms_list = []
        for transform_name, arguments in self.applied_transforms:
            transform = name_to_transform[transform_name](**arguments)
            if ignore_intensity and isinstance(transform, IntensityTransform):
                continue
            resamples = hasattr(transform, 'image_interpolation')
            if resamples and image_interpolation is not None:
                parsed = transform.parse_interpolation(image_interpolation)
                transform.image_interpolation = parsed
            transforms_list.append(transform)
        return transforms_list

    def get_composed_history(
            self,
            ignore_intensity: bool = False,
            image_interpolation: Optional[str] = None,
    ) -> 'Compose':
        """Return the applied transforms wrapped in a single ``Compose``.

        Arguments are forwarded to :meth:`get_applied_transforms`.
        """
        from ..transforms.augmentation.composition import Compose
        transforms = self.get_applied_transforms(
            ignore_intensity=ignore_intensity,
            image_interpolation=image_interpolation,
        )
        return Compose(transforms)

    def get_inverse_transform(
            self,
            warn: bool = True,
            ignore_intensity: bool = True,
            image_interpolation: Optional[str] = None,
    ) -> 'Compose':
        """Get a reversed list of the inverses of the applied transforms.

        Args:
            warn: Issue a warning if some transforms are not invertible.
            ignore_intensity: If ``True``, all instances of
                :class:`~torchio.transforms.intensity_transform.IntensityTransform`
                will be ignored.
            image_interpolation: Modify interpolation for scalar images inside
                transforms that perform resampling.
        """
        history_transform = self.get_composed_history(
            ignore_intensity=ignore_intensity,
            image_interpolation=image_interpolation,
        )
        inverse_transform = history_transform.inverse(warn=warn)
        return inverse_transform

    def apply_inverse_transform(self, **kwargs) -> 'Subject':
        """Try to apply the inverse of all applied transforms, in reverse order.

        The history of the returned subject is cleared.

        Args:
            **kwargs: Keyword arguments passed on to
                :meth:`~torchio.data.subject.Subject.get_inverse_transform`.
        """
        inverse_transform = self.get_inverse_transform(**kwargs)
        transformed = inverse_transform(self)
        transformed.clear_history()
        return transformed

    def clear_history(self) -> None:
        """Forget all recorded transforms."""
        self.applied_transforms = []

    def check_consistent_attribute(self, attribute: str) -> None:
        """Check that all images share the same value of ``attribute``.

        Values are compared for exact equality (they must be hashable).

        Raises:
            RuntimeError: If more than one distinct value is found.
        """
        values_dict = {}
        iterable = self.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            values_dict[image_name] = getattr(image, attribute)
        num_unique_values = len(set(values_dict.values()))
        if num_unique_values > 1:
            message = (
                f'More than one {attribute} found in subject images:'
                f'\n{pprint.pformat(values_dict)}'
            )
            raise RuntimeError(message)

    def check_consistent_spatial_shape(self) -> None:
        """Check that all images have the same spatial shape."""
        self.check_consistent_attribute('spatial_shape')

    def check_consistent_orientation(self) -> None:
        """Check that all images have the same orientation."""
        self.check_consistent_attribute('orientation')

    def check_consistent_affine(self):
        """Check that all images have (approximately) the same affine matrix.

        Every affine is compared against the first image's affine with
        :func:`numpy.allclose` using ``rtol=1e-6`` and ``atol=1e-6``.

        Raises:
            RuntimeError: If an image's affine differs from the first one.
        """
        # https://github.com/fepegar/torchio/issues/354
        affine = None
        first_image = None
        iterable = self.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            if affine is None:
                affine = image.affine
                first_image = image_name
            elif not np.allclose(affine, image.affine, rtol=1e-6, atol=1e-6):
                message = (
                    f'Images "{first_image}" and "{image_name}" do not occupy'
                    ' the same physical space.'
                    f'\nAffine of "{first_image}":'
                    f'\n{pprint.pformat(affine)}'
                    f'\nAffine of "{image_name}":'
                    f'\n{pprint.pformat(image.affine)}'
                )
                raise RuntimeError(message)

    def check_consistent_space(self):
        """Check that all images share spatial shape and affine."""
        self.check_consistent_spatial_shape()
        self.check_consistent_affine()

    def get_images_dict(
            self,
            intensity_only=True,
            include: Optional[Sequence[str]] = None,
            exclude: Optional[Sequence[str]] = None,
    ) -> Dict[str, Image]:
        """Return the subject's images keyed by name, in insertion order.

        Args:
            intensity_only: If ``True``, keep only images whose ``TYPE`` is
                ``INTENSITY``.
            include: If given, keep only images whose name is in it.
            exclude: If given, drop images whose name is in it.
        """
        images = {}
        for image_name, image in self.items():
            if not isinstance(image, Image):
                continue
            if intensity_only and image[TYPE] != INTENSITY:
                continue
            if include is not None and image_name not in include:
                continue
            if exclude is not None and image_name in exclude:
                continue
            images[image_name] = image
        return images

    def get_images(
            self,
            intensity_only=True,
            include: Optional[Sequence[str]] = None,
            exclude: Optional[Sequence[str]] = None,
    ) -> List[Image]:
        """Return the images selected by :meth:`get_images_dict` as a list."""
        images_dict = self.get_images_dict(
            intensity_only=intensity_only,
            include=include,
            exclude=exclude,
        )
        return list(images_dict.values())

    def get_first_image(self) -> Image:
        """Return the first image (of any type) in the subject.

        Raises:
            IndexError: If the subject currently has no images.
        """
        return self.get_images(intensity_only=False)[0]

    def add_transform(
            self,
            transform: 'Transform',
            parameters_dict: dict,
    ) -> None:
        """Record that ``transform`` was applied with ``parameters_dict``."""
        self.applied_transforms.append((transform.name, parameters_dict))

    def load(self) -> None:
        """Load images in subject on RAM."""
        for image in self.get_images(intensity_only=False):
            image.load()

    def update_attributes(self) -> None:
        """Mirror every dict item as an instance attribute."""
        # This allows to get images using attribute notation, e.g. subject.t1
        self.__dict__.update(self)

    def add_image(self, image: Image, image_name: str) -> None:
        """Add an image and expose it as an attribute."""
        self[image_name] = image
        self.update_attributes()

    def remove_image(self, image_name: str) -> None:
        """Remove an image and its attribute mirror.

        Raises:
            KeyError: If ``image_name`` is not a key of the subject.
        """
        del self[image_name]
        # The attribute may be missing (item set via ``subject[key] = ...``
        # without update_attributes) or shadowed by a property such as
        # ``shape``; popping from __dict__ avoids failing half-way through.
        self.__dict__.pop(image_name, None)

    def plot(self, **kwargs) -> None:
        """Plot images using matplotlib.

        Args:
            **kwargs: Keyword arguments that will be passed on to
                :func:`~torchio.visualization.plot_subject`.
        """
        from ..visualization import plot_subject  # avoid circular import
        plot_subject(self, **kwargs)