torchio.data.subject.Subject.shape()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 16
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 16
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
import copy
2
import pprint
3
from typing import Any, Dict, List, Tuple, Optional, Sequence, TYPE_CHECKING
4
5
import numpy as np
6
7
from ..constants import TYPE, INTENSITY
8
from .image import Image
9
from ..utils import get_subclasses
10
11
if TYPE_CHECKING:
12
    from ..transforms import Transform, Compose
13
14
15
class Subject(dict):
16
    """Class to store information about the images corresponding to a subject.
17
18
    Args:
19
        *args: If provided, a dictionary of items.
20
        **kwargs: Items that will be added to the subject sample.
21
22
    Example:
23
24
        >>> import torchio as tio
25
        >>> # One way:
26
        >>> subject = tio.Subject(
27
        ...     one_image=tio.ScalarImage('path_to_image.nii.gz'),
28
        ...     a_segmentation=tio.LabelMap('path_to_seg.nii.gz'),
29
        ...     age=45,
30
        ...     name='John Doe',
31
        ...     hospital='Hospital Juan Negrín',
32
        ... )
33
        >>> # If you want to create the mapping before, or have spaces in the keys:
34
        >>> subject_dict = {
35
        ...     'one image': tio.ScalarImage('path_to_image.nii.gz'),
36
        ...     'a segmentation': tio.LabelMap('path_to_seg.nii.gz'),
37
        ...     'age': 45,
38
        ...     'name': 'John Doe',
39
        ...     'hospital': 'Hospital Juan Negrín',
40
        ... }
41
        >>> subject = tio.Subject(subject_dict)
42
43
    """
44
45
    def __init__(self, *args, **kwargs: Dict[str, Any]):
46
        if args:
47
            if len(args) == 1 and isinstance(args[0], dict):
48
                kwargs.update(args[0])
49
            else:
50
                message = (
51
                    'Only one dictionary as positional argument is allowed')
52
                raise ValueError(message)
53
        super().__init__(**kwargs)
54
        self._parse_images(self.get_images(intensity_only=False))
55
        self.update_attributes()  # this allows me to do e.g. subject.t1
56
        self.applied_transforms = []
57
58
    def __repr__(self):
59
        num_images = len(self.get_images(intensity_only=False))
60
        string = (
61
            f'{self.__class__.__name__}'
62
            f'(Keys: {tuple(self.keys())}; images: {num_images})'
63
        )
64
        return string
65
66
    def __copy__(self):
67
        result_dict = {}
68
        for key, value in self.items():
69
            if isinstance(value, Image):
70
                value = copy.copy(value)
71
            else:
72
                value = copy.deepcopy(value)
73
            result_dict[key] = value
74
        new = Subject(result_dict)
75
        new.applied_transforms = self.applied_transforms[:]
76
        return new
77
78
    def __len__(self):
79
        return len(self.get_images(intensity_only=False))
80
81
    @staticmethod
82
    def _parse_images(images: List[Tuple[str, Image]]) -> None:
83
        # Check that it's not empty
84
        if not images:
85
            raise ValueError('A subject without images cannot be created')
86
87
    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.

        Example::

            >>> import torchio as tio
            >>> colin = tio.datasets.Colin27()
            >>> colin.shape
            (1, 181, 217, 181)

        """
        self.check_consistent_attribute('shape')
        first_image = self.get_first_image()
        return first_image.shape
103
104
    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        Consistency of spatial shapes across images in the subject is checked
        first.

        Example::

            >>> import torchio as tio
            >>> colin = tio.datasets.Colin27()
            >>> colin.spatial_shape
            (181, 217, 181)

        """
        self.check_consistent_spatial_shape()
        return self.get_first_image().spatial_shape
120
121
    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of spacings across images in the subject is checked first.

        Example::

            >>> import torchio as tio
            >>> subject = tio.datasets.Slicer()
            >>> subject.spacing
            (1.0, 1.0, 1.2999954223632812)

        """
        self.check_consistent_attribute('spacing')
        return self.get_first_image().spacing
136
137
    @property
    def history(self):
        """Return the result of :meth:`get_applied_transforms`.

        Kept for backwards compatibility.
        """
        applied = self.get_applied_transforms()
        return applied
141
142
    def is_2d(self):
        """Return ``True`` if ``is_2d()`` holds for every image in the subject."""
        for image in self.get_images(intensity_only=False):
            if not image.is_2d():
                return False
        return True
144
145
    def get_applied_transforms(
            self,
            ignore_intensity: bool = False,
            image_interpolation: Optional[str] = None,
            ) -> List['Transform']:
        """Re-instantiate the transforms recorded in ``applied_transforms``.

        Each record is a ``(name, arguments)`` pair. The class is looked up by
        name among the subclasses of ``Transform`` and instantiated with the
        recorded arguments.

        Args:
            ignore_intensity: If ``True``, instances of ``IntensityTransform``
                are left out of the result.
            image_interpolation: If not ``None``, it is parsed with the
                transform's ``parse_interpolation`` and assigned to every
                transform that has an ``image_interpolation`` attribute.

        Returns:
            The instantiated transforms, in the order they were recorded.
        """
        # Imported here, not at module level (presumably to avoid a circular
        # import)
        from ..transforms.transform import Transform
        from ..transforms.intensity_transform import IntensityTransform

        classes_by_name = {}
        for transform_class in get_subclasses(Transform):
            classes_by_name[transform_class.__name__] = transform_class

        rebuilt = []
        for name, parameters in self.applied_transforms:
            instance = classes_by_name[name](**parameters)
            if ignore_intensity and isinstance(instance, IntensityTransform):
                continue
            has_interpolation = hasattr(instance, 'image_interpolation')
            if has_interpolation and image_interpolation is not None:
                instance.image_interpolation = instance.parse_interpolation(
                    image_interpolation)
            rebuilt.append(instance)
        return rebuilt
167
168
    def get_composed_history(
            self,
            ignore_intensity: bool = False,
            image_interpolation: Optional[str] = None,
            ) -> 'Compose':
        """Wrap the applied transforms in a single ``Compose`` transform.

        Both arguments are forwarded unchanged to
        :meth:`get_applied_transforms`.
        """
        from ..transforms.augmentation.composition import Compose
        return Compose(
            self.get_applied_transforms(
                ignore_intensity=ignore_intensity,
                image_interpolation=image_interpolation,
            )
        )
179
180
    def get_inverse_transform(
            self,
            warn: bool = True,
            ignore_intensity: bool = True,
            image_interpolation: Optional[str] = None,
            ) -> 'Compose':
        """Get a reversed list of the inverses of the applied transforms.

        Args:
            warn: Issue a warning if some transforms are not invertible.
            ignore_intensity: If ``True``, all instances of
                :class:`~torchio.transforms.intensity_transform.IntensityTransform`
                will be ignored.
            image_interpolation: Modify interpolation for scalar images inside
                transforms that perform resampling.
        """
        composed = self.get_composed_history(
            ignore_intensity=ignore_intensity,
            image_interpolation=image_interpolation,
        )
        return composed.inverse(warn=warn)
202
203
    def apply_inverse_transform(self, **kwargs) -> 'Subject':
        """Try to apply the inverse of all applied transforms, in reverse order.

        The history of the returned subject is cleared before returning it.

        Args:
            **kwargs: Keyword arguments passed on to
                :meth:`~torchio.data.subject.Subject.get_inverse_transform`.
        """
        inverse = self.get_inverse_transform(**kwargs)
        restored = inverse(self)
        restored.clear_history()
        return restored
214
215
    def clear_history(self) -> None:
        """Replace the record of applied transforms with a new empty list."""
        self.applied_transforms = list()
217
218
    def check_consistent_attribute(
            self,
            attribute: str,
            relative_tolerance: float = 1e-6,
            absolute_tolerance: float = 1e-6,
            message: Optional[str] = None,
            ) -> None:
        r"""Check for consistency of an attribute across all images.

        Args:
            attribute: Name of the image attribute to check
            relative_tolerance: Relative tolerance for :func:`numpy.allclose()`
            absolute_tolerance: Absolute tolerance for :func:`numpy.allclose()`
            message: Error message used when values differ. Every ``{}`` in it
                is replaced by a formatted mapping from image names to the
                differing values. If ``None``, a default message is used.

        Raises:
            RuntimeError: If the attribute is not the same in all images.

        Example:
            >>> import numpy as np
            >>> import torch
            >>> import torchio as tio
            >>> scalars = torch.randn(1, 512, 512, 100)
            >>> mask = (scalars > 0).type(torch.int16)
            >>> af1 = np.diag([0.8, 0.8, 2.50000000000001, 1])
            >>> af2 = np.diag([0.8, 0.8, 2.49999999999999, 1])  # small difference here (e.g. due to different reader)
            >>> subject = tio.Subject(
            ...   image = tio.ScalarImage(tensor=scalars, affine=af1),
            ...   mask = tio.LabelMap(tensor=mask, affine=af2)
            ... )
            >>> subject.check_consistent_attribute('spacing')  # no error as tolerances are > 0

        .. note:: To check that all values for a specific attribute are close
            between all images in the subject, :func:`numpy.allclose()` is used.
            This function returns ``True`` if
            :math:`|a_i - b_i| \leq t_{abs} + t_{rel} * |b_i|`, where
            :math:`a_i` and :math:`b_i` are the :math:`i`-th element of the same
            attribute of two images being compared,
            :math:`t_{abs}` is the ``absolute_tolerance`` and
            :math:`t_{rel}` is the ``relative_tolerance``.
        """
        # Honor a caller-supplied message; fall back to the default template
        if message is None:
            message = (
                f'More than one value for "{attribute}" found in subject images:'
                '\n{}'
            )

        names_images = self.get_images_dict(intensity_only=False).items()
        try:
            first_attribute = None
            first_image = None

            for image_name, image in names_images:
                if first_attribute is None:
                    first_attribute = getattr(image, attribute)
                    first_image = image_name
                    continue
                current_attribute = getattr(image, attribute)
                all_close = np.allclose(
                    current_attribute,
                    first_attribute,
                    rtol=relative_tolerance,
                    atol=absolute_tolerance,
                )
                if not all_close:
                    details = pprint.pformat({
                        first_image: first_attribute,
                        image_name: current_attribute
                    })
                    # str.replace instead of str.format so that stray braces
                    # in a custom message cannot raise while reporting
                    raise RuntimeError(message.replace('{}', details))
        except TypeError:
            # fallback for non-numeric values
            values_dict = {
                image_name: getattr(image, attribute)
                for image_name, image in names_images
            }
            num_unique_values = len(set(values_dict.values()))
            if num_unique_values > 1:
                details = pprint.pformat(values_dict)
                raise RuntimeError(message.replace('{}', details)) from None
294
295
    def check_consistent_spatial_shape(self) -> None:
        """Check that ``spatial_shape`` is consistent across all images."""
        attribute_name = 'spatial_shape'
        self.check_consistent_attribute(attribute_name)
297
298
    def check_consistent_orientation(self) -> None:
        """Check that ``orientation`` is consistent across all images."""
        attribute_name = 'orientation'
        self.check_consistent_attribute(attribute_name)
300
301
    def check_consistent_affine(self) -> None:
        """Check that ``affine`` is consistent across all images."""
        attribute_name = 'affine'
        self.check_consistent_attribute(attribute_name)
303
304
    def check_consistent_space(self) -> None:
        """Check that all images are in the same space.

        Spacing, direction, origin and spatial shape are checked, in that
        order.

        Raises:
            RuntimeError: If any of the checks fails. The error raised by the
                failing check is chained as the cause.
        """
        checked_attributes = ('spacing', 'direction', 'origin')
        try:
            for attribute in checked_attributes:
                self.check_consistent_attribute(attribute)
            self.check_consistent_spatial_shape()
        except RuntimeError as error:
            hint = (
                'As described above, some images in the subject are not in the'
                ' same space. You probably can use the transforms ToCanonical'
                ' and Resample to fix this, as explained at'
                ' https://github.com/fepegar/torchio/issues/647#issuecomment-913025695'
            )
            raise RuntimeError(hint) from error
318
319
    def get_images_names(self) -> List[str]:
        """Return the keys of all images, with intensity filtering disabled."""
        images_dict = self.get_images_dict(intensity_only=False)
        return list(images_dict.keys())
321
322
    def get_images_dict(
            self,
            intensity_only=True,
            include: Optional[Sequence[str]] = None,
            exclude: Optional[Sequence[str]] = None,
            ) -> Dict[str, Image]:
        """Return the items of the subject that are images.

        Args:
            intensity_only: If ``True``, keep only images whose ``TYPE`` entry
                equals ``INTENSITY``.
            include: If not ``None``, keep only images whose name is in it.
            exclude: If not ``None``, drop images whose name is in it.

        Returns:
            Mapping from image name to image, in the subject's item order.
        """
        def is_selected(name, item) -> bool:
            # The checks run in the same order as the filters are documented
            if not isinstance(item, Image):
                return False
            if intensity_only and not item[TYPE] == INTENSITY:
                return False
            if include is not None and name not in include:
                return False
            if exclude is not None and name in exclude:
                return False
            return True

        return {
            name: item
            for name, item in self.items()
            if is_selected(name, item)
        }
340
341
    def get_images(
            self,
            intensity_only=True,
            include: Optional[Sequence[str]] = None,
            exclude: Optional[Sequence[str]] = None,
            ) -> List[Image]:
        """Return the values of :meth:`get_images_dict` as a list.

        All arguments are forwarded unchanged to :meth:`get_images_dict`.
        """
        selection = dict(
            intensity_only=intensity_only,
            include=include,
            exclude=exclude,
        )
        return list(self.get_images_dict(**selection).values())
353
354
    def get_first_image(self) -> Image:
        """Return the first image, with intensity filtering disabled."""
        all_images = self.get_images(intensity_only=False)
        return all_images[0]
356
357
    # flake8: noqa: F821
358
    def add_transform(
            self,
            transform: 'Transform',
            parameters_dict: dict,
            ) -> None:
        """Record a transform in the history of the subject.

        Args:
            transform: Transform whose ``name`` is stored.
            parameters_dict: Arguments stored alongside the name.
        """
        record = (transform.name, parameters_dict)
        self.applied_transforms.append(record)
364
365
    def load(self) -> None:
        """Load images in subject on RAM."""
        all_images = self.get_images(intensity_only=False)
        for image in all_images:
            image.load()
369
370
    def update_attributes(self) -> None:
        """Mirror every item of the subject as an instance attribute."""
        # This allows to get images using attribute notation, e.g. subject.t1
        instance_namespace = vars(self)
        instance_namespace.update(self)
373
374
    def add_image(self, image: Image, image_name: str) -> None:
        """Add an image.

        Args:
            image: Image stored under ``image_name``.
            image_name: Key (and attribute name) used for the image.
        """
        self[image_name] = image
        # Refresh the mirrored attributes so the new item is reachable too
        self.update_attributes()
378
379
    def remove_image(self, image_name: str) -> None:
        """Remove an image.

        Args:
            image_name: Key of the item to remove.

        Raises:
            KeyError: If ``image_name`` is not a key of the subject.
        """
        del self[image_name]
        # The item may have been added without being mirrored as an attribute
        # (e.g. with subject[name] = image), so tolerate a missing attribute
        # instead of failing after the item has already been deleted
        vars(self).pop(image_name, None)
383
384
    def plot(self, **kwargs) -> None:
        """Plot images using matplotlib.

        Args:
            **kwargs: Keyword arguments that will be passed on to
                :meth:`~torchio.Image.plot`.
        """
        # Imported lazily to avoid a circular import
        from ..visualization import plot_subject
        plot_subject(self, **kwargs)
393