1
|
|
|
import copy |
2
|
|
|
import pprint |
3
|
|
|
from typing import Any, Dict, List, Tuple, Optional, Sequence |
4
|
|
|
|
5
|
|
|
import numpy as np |
6
|
|
|
|
7
|
|
|
from ..constants import TYPE, INTENSITY |
8
|
|
|
from .image import Image |
9
|
|
|
from ..utils import get_subclasses |
10
|
|
|
|
11
|
|
|
|
12
|
|
|
class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Args:
        *args: If provided, a dictionary of items. Its entries take
            precedence over keyword items that use the same key.
        **kwargs: Items that will be added to the subject sample.

    Raises:
        ValueError: If positional arguments other than a single ``dict`` are
            given, or if none of the values is an :class:`Image`.

    Example:

        >>> import torchio as tio
        >>> # One way:
        >>> subject = tio.Subject(
        ...     one_image=tio.ScalarImage('path_to_image.nii.gz'),
        ...     a_segmentation=tio.LabelMap('path_to_seg.nii.gz'),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping before, or have spaces in the keys:
        >>> subject_dict = {
        ...     'one image': tio.ScalarImage('path_to_image.nii.gz'),
        ...     'a segmentation': tio.LabelMap('path_to_seg.nii.gz'),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> subject = tio.Subject(subject_dict)

    """

    def __init__(self, *args, **kwargs: Any):
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                # Positional dict entries overwrite same-named keyword items.
                kwargs.update(args[0])
            else:
                message = (
                    'Only one dictionary as positional argument is allowed')
                raise ValueError(message)
        super().__init__(**kwargs)
        self._parse_images(self.get_images(intensity_only=False))
        self.update_attributes()  # this allows me to do e.g. subject.t1
        # Set after update_attributes so a key with this name cannot
        # replace the history list.
        self.applied_transforms = []

    def __repr__(self):
        """Show the class name, the keys and the number of images."""
        num_images = len(self.get_images(intensity_only=False))
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {num_images})'
        )
        return string

    def __copy__(self):
        """Return a new subject with copied values and history.

        :class:`Image` values are copied with ``copy.copy``. All other
        values are copied with ``copy.deepcopy``. The list of applied
        transforms is sliced, so appending to one subject's history does not
        affect the other.
        """
        result_dict = {}
        for key, value in self.items():
            if isinstance(value, Image):
                value = copy.copy(value)
            else:
                value = copy.deepcopy(value)
            result_dict[key] = value
        # NOTE(review): this always builds a base ``Subject``, so subclass
        # types are not preserved by ``copy.copy``.
        new = Subject(result_dict)
        new.applied_transforms = self.applied_transforms[:]
        return new

    def __len__(self):
        """Return the number of images, not the number of keys."""
        return len(self.get_images(intensity_only=False))

    @staticmethod
    def _parse_images(images: List[Image]) -> None:
        """Validate the images found in a new subject.

        Raises:
            ValueError: If ``images`` is empty.
        """
        # Check that it's not empty
        if not images:
            raise ValueError('A subject without images cannot be created')

    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.
        """
        self.check_consistent_attribute('shape')
        return self.get_first_image().shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        Consistency of spatial shapes across images in the subject is checked
        first.
        """
        self.check_consistent_spatial_shape()
        return self.get_first_image().spatial_shape

    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of spacings across images in the subject is checked first.
        """
        self.check_consistent_attribute('spacing')
        return self.get_first_image().spacing

    @property
    def history(self):
        """Alias of :meth:`get_applied_transforms`."""
        # Kept for backwards compatibility
        return self.get_applied_transforms()

    def get_applied_transforms(self, ignore_intensity: bool = False):
        """Instantiate the transforms recorded in ``applied_transforms``.

        Each recorded ``(name, arguments)`` pair is resolved by class name
        among the subclasses of ``Transform``. The class is then called with
        the recorded arguments.

        Args:
            ignore_intensity: If ``True``, skip instances of
                ``IntensityTransform``.

        Returns:
            A list of transform instances in the order they were recorded.

        Raises:
            KeyError: If a recorded name does not match any ``Transform``
                subclass.
        """
        from ..transforms.transform import Transform
        from ..transforms.intensity_transform import IntensityTransform
        name_to_transform = {
            cls.__name__: cls
            for cls in get_subclasses(Transform)
        }
        transforms_list = []
        for transform_name, arguments in self.applied_transforms:
            transform = name_to_transform[transform_name](**arguments)
            if ignore_intensity and isinstance(transform, IntensityTransform):
                continue
            transforms_list.append(transform)
        return transforms_list

    def get_composed_history(
            self,
            ignore_intensity: bool = False,
            ) -> 'Transform':  # noqa: F821
        """Return the applied transforms wrapped in a single ``Compose``."""
        from ..transforms.augmentation.composition import Compose
        transforms = self.get_applied_transforms(
            ignore_intensity=ignore_intensity)
        return Compose(transforms)

    def get_inverse_transform(
            self,
            warn: bool = True,
            ignore_intensity: bool = True,
            ) -> 'Transform':  # noqa: F821
        """Return ``inverse(warn=warn)`` of the composed history."""
        history_transform = self.get_composed_history(
            ignore_intensity=ignore_intensity)
        inverse_transform = history_transform.inverse(warn=warn)
        return inverse_transform

    def apply_inverse_transform(
            self,
            warn: bool = True,
            ignore_intensity: bool = True,
            ) -> 'Subject':
        """Apply the inverse of the history to this subject.

        Returns:
            The subject returned by the inverse transform, with its history
            cleared.
        """
        inverse_transform = self.get_inverse_transform(
            warn=warn,
            ignore_intensity=ignore_intensity,
        )
        transformed = inverse_transform(self)
        transformed.clear_history()
        return transformed

    def clear_history(self) -> None:
        """Forget all recorded transforms."""
        self.applied_transforms = []

    def check_consistent_attribute(self, attribute: str) -> None:
        """Check that all images share the same value of ``attribute``.

        The values are compared through a ``set``, so they must be hashable.

        Raises:
            RuntimeError: If more than one distinct value is found.
        """
        values_dict = {}
        iterable = self.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            values_dict[image_name] = getattr(image, attribute)
        num_unique_values = len(set(values_dict.values()))
        if num_unique_values > 1:
            message = (
                f'More than one {attribute} found in subject images:'
                f'\n{pprint.pformat(values_dict)}'
            )
            raise RuntimeError(message)

    def check_consistent_spatial_shape(self) -> None:
        """Check that all images share the same spatial shape."""
        self.check_consistent_attribute('spatial_shape')

    def check_consistent_orientation(self) -> None:
        """Check that all images share the same orientation."""
        self.check_consistent_attribute('orientation')

    def check_consistent_affine(self):
        """Check that every image's affine matches the first image's.

        The comparison uses ``np.allclose`` with ``rtol=atol=1e-6``.

        Raises:
            RuntimeError: Naming the first mismatching pair of images.
        """
        # https://github.com/fepegar/torchio/issues/354
        affine = None
        first_image = None
        iterable = self.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            if affine is None:
                affine = image.affine
                first_image = image_name
            elif not np.allclose(affine, image.affine, rtol=1e-6, atol=1e-6):
                message = (
                    f'Images "{first_image}" and "{image_name}" do not occupy'
                    ' the same physical space.'
                    f'\nAffine of "{first_image}":'
                    f'\n{pprint.pformat(affine)}'
                    f'\nAffine of "{image_name}":'
                    f'\n{pprint.pformat(image.affine)}'
                )
                raise RuntimeError(message)

    def check_consistent_space(self):
        """Check both the spatial shapes and the affines of all images."""
        self.check_consistent_spatial_shape()
        self.check_consistent_affine()

    def get_images_dict(
            self,
            intensity_only: bool = True,
            include: Optional[Sequence[str]] = None,
            exclude: Optional[Sequence[str]] = None,
            ) -> Dict[str, Image]:
        """Return the :class:`Image` values keyed by name, with filters.

        Args:
            intensity_only: Keep only images whose ``TYPE`` is
                ``INTENSITY``.
            include: If given, keep only these names.
            exclude: If given, drop these names.
        """
        images = {}
        for image_name, image in self.items():
            if not isinstance(image, Image):
                continue
            if intensity_only and image[TYPE] != INTENSITY:
                continue
            if include is not None and image_name not in include:
                continue
            if exclude is not None and image_name in exclude:
                continue
            images[image_name] = image
        return images

    def get_images(
            self,
            intensity_only: bool = True,
            include: Optional[Sequence[str]] = None,
            exclude: Optional[Sequence[str]] = None,
            ) -> List[Image]:
        """Return the images selected by :meth:`get_images_dict` as a list."""
        images_dict = self.get_images_dict(
            intensity_only=intensity_only,
            include=include,
            exclude=exclude,
        )
        return list(images_dict.values())

    def get_first_image(self) -> Image:
        """Return the first image in insertion order.

        Raises:
            IndexError: If the subject holds no images.
        """
        return self.get_images(intensity_only=False)[0]

    def add_transform(
            self,
            transform: 'Transform',  # noqa: F821
            parameters_dict: dict,
            ) -> None:
        """Record ``(transform.name, parameters_dict)`` in the history."""
        self.applied_transforms.append((transform.name, parameters_dict))

    def load(self) -> None:
        """Load images in subject."""
        for image in self.get_images(intensity_only=False):
            image.load()

    def update_attributes(self) -> None:
        """Mirror every key/value pair into the instance ``__dict__``."""
        # This allows to get images using attribute notation, e.g. subject.t1
        self.__dict__.update(self)

    def add_image(self, image: Image, image_name: str) -> None:
        """Add an image and expose it as an attribute."""
        self[image_name] = image
        self.update_attributes()

    def remove_image(self, image_name: str) -> None:
        """Remove an image and the attribute mirrored by ``update_attributes``.

        Raises:
            KeyError: If ``image_name`` is not a key of the subject.
        """
        image = self[image_name]
        del self[image_name]
        # Drop the mirrored attribute only when it still refers to the
        # removed value. This avoids clobbering an unrelated instance
        # attribute (e.g. ``applied_transforms``) that shares the name.
        if self.__dict__.get(image_name) is image:
            del self.__dict__[image_name]

    def plot(self, **kwargs) -> None:
        """Plot images."""
        from ..visualization import plot_subject  # avoid circular import
        plot_subject(self, **kwargs)
271
|
|
|
|