torchio.data.subject.Subject.plot()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 2
dl 0
loc 10
rs 10
c 0
b 0
f 0
1
from __future__ import annotations
2
3
import copy
4
import pprint
5
from collections.abc import Sequence
6
from typing import TYPE_CHECKING
7
from typing import Any
8
9
import numpy as np
10
11
from ..constants import INTENSITY
12
from ..constants import TYPE
13
from ..utils import get_subclasses
14
from .image import Image
15
16
if TYPE_CHECKING:
17
    from ..transforms import Compose
18
    from ..transforms import Transform
19
20
21
class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Args:
        *args: If provided, a dictionary of items.
        **kwargs: Items that will be added to the subject sample.

    Example:

        >>> import torchio as tio
        >>> # One way:
        >>> subject = tio.Subject(
        ...     one_image=tio.ScalarImage('path_to_image.nii.gz'),
        ...     a_segmentation=tio.LabelMap('path_to_seg.nii.gz'),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping before, or have spaces in the keys:
        >>> subject_dict = {
        ...     'one image': tio.ScalarImage('path_to_image.nii.gz'),
        ...     'a segmentation': tio.LabelMap('path_to_seg.nii.gz'),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> subject = tio.Subject(subject_dict)
    """

    def __init__(self, *args, **kwargs: Any):
        # A single positional dictionary is merged into the keyword items; its
        # entries take precedence over keyword arguments with the same key.
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                kwargs.update(args[0])
            else:
                message = 'Only one dictionary as positional argument is allowed'
                raise ValueError(message)
        super().__init__(**kwargs)
        # Raises TypeError if no value is an Image
        self._parse_images(self.get_images(intensity_only=False))
        self.update_attributes()  # this allows me to do e.g. subject.t1
        # (transform name, arguments) records appended by add_transform()
        self.applied_transforms: list[tuple[str, dict]] = []

    def __repr__(self):
        num_images = len(self.get_images(intensity_only=False))
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {num_images})'
        )
        return string

    def __len__(self):
        """Return the number of images in the subject (not the number of keys)."""
        return len(self.get_images(intensity_only=False))

    def __getitem__(self, item):
        """Get a value by key, or index all the images in the subject.

        If ``item`` is a :class:`slice`, an :class:`int` or a :class:`tuple`,
        a deep copy of the subject is returned in which every image has been
        replaced by ``image[item]``. Values that are not images (e.g. ``age``
        or ``name``) are kept unchanged. Any other ``item`` is looked up as a
        regular dictionary key.

        Raises:
            RuntimeError: If indexing is requested and the images in the
                subject do not all have the same spatial shape.
        """
        if not isinstance(item, (slice, int, tuple)):
            return super().__getitem__(item)
        try:
            self.check_consistent_spatial_shape()
        except RuntimeError as e:
            message = (
                'To use indexing, all images in the subject must have the'
                ' same spatial shape'
            )
            raise RuntimeError(message) from e
        copied = copy.deepcopy(self)
        # Only images support this indexing; applying it to other values would
        # silently slice strings or raise TypeError for numbers.
        images = copied.get_images_dict(intensity_only=False)
        for image_name, image in images.items():
            copied[image_name] = image[item]
        # The deep copy also duplicated the attribute aliases (e.g. copied.t1),
        # which still point to the unindexed images, so refresh them.
        copied.update_attributes()
        return copied

    @staticmethod
    def _parse_images(images: list[Image]) -> None:
        # Check that it's not empty
        if not images:
            raise TypeError('A subject without images cannot be created')

    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.

        Example:

            >>> import torchio as tio
            >>> colin = tio.datasets.Colin27()
            >>> colin.shape
            (1, 181, 217, 181)
        """
        self.check_consistent_attribute('shape')
        return self.get_first_image().shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        Consistency of spatial shapes across images in the subject is checked
        first.

        Example:

            >>> import torchio as tio
            >>> colin = tio.datasets.Colin27()
            >>> colin.spatial_shape
            (181, 217, 181)
        """
        self.check_consistent_spatial_shape()
        return self.get_first_image().spatial_shape

    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of spacings across images in the subject is checked first.

        Example:

            >>> import torchio as tio
            >>> slicer = tio.datasets.Slicer()
            >>> slicer.spacing
            (1.0, 1.0, 1.2999954223632812)
        """
        self.check_consistent_attribute('spacing')
        return self.get_first_image().spacing

    @property
    def history(self):
        """Alias of :meth:`get_applied_transforms` with default arguments."""
        # Kept for backwards compatibility
        return self.get_applied_transforms()

    def is_2d(self):
        """Return ``True`` if every image in the subject reports being 2D.

        As with :func:`all`, this is also ``True`` when there are no images.
        """
        return all(i.is_2d() for i in self.get_images(intensity_only=False))

    def get_applied_transforms(
        self,
        ignore_intensity: bool = False,
        image_interpolation: str | None = None,
    ) -> list[Transform]:
        """Rebuild the transforms recorded in ``applied_transforms``.

        Each record is instantiated from the
        :class:`~torchio.transforms.Transform` subclass whose class name matches
        the recorded name, using the recorded arguments.

        Args:
            ignore_intensity: If ``True``, instances of
                :class:`~torchio.transforms.intensity_transform.IntensityTransform`
                are skipped.
            image_interpolation: If not ``None``, it is parsed with the
                transform's ``parse_interpolation`` and assigned to the
                ``image_interpolation`` attribute of every transform that has
                one.

        Returns:
            The transforms, in the order in which they were recorded.

        Raises:
            KeyError: If a recorded name matches no ``Transform`` subclass.
        """
        from ..transforms.intensity_transform import IntensityTransform
        from ..transforms.transform import Transform

        name_to_transform = {cls.__name__: cls for cls in get_subclasses(Transform)}
        transforms_list = []
        for transform_name, arguments in self.applied_transforms:
            transform = name_to_transform[transform_name](**arguments)
            if ignore_intensity and isinstance(transform, IntensityTransform):
                continue
            resamples = hasattr(transform, 'image_interpolation')
            if resamples and image_interpolation is not None:
                parsed = transform.parse_interpolation(image_interpolation)
                transform.image_interpolation = parsed
            transforms_list.append(transform)
        return transforms_list

    def get_composed_history(
        self,
        ignore_intensity: bool = False,
        image_interpolation: str | None = None,
    ) -> Compose:
        """Return the recorded transforms wrapped in a ``Compose``.

        Args:
            ignore_intensity: See :meth:`get_applied_transforms`.
            image_interpolation: See :meth:`get_applied_transforms`.
        """
        from ..transforms.augmentation.composition import Compose

        transforms = self.get_applied_transforms(
            ignore_intensity=ignore_intensity,
            image_interpolation=image_interpolation,
        )
        return Compose(transforms)

    def get_inverse_transform(
        self,
        warn: bool = True,
        ignore_intensity: bool = False,
        image_interpolation: str | None = None,
    ) -> Compose:
        """Get a reversed list of the inverses of the applied transforms.

        Args:
            warn: Issue a warning if some transforms are not invertible.
            ignore_intensity: If ``True``, all instances of
                :class:`~torchio.transforms.intensity_transform.IntensityTransform`
                will be ignored.
            image_interpolation: Modify interpolation for scalar images inside
                transforms that perform resampling.
        """
        history_transform = self.get_composed_history(
            ignore_intensity=ignore_intensity,
            image_interpolation=image_interpolation,
        )
        inverse_transform = history_transform.inverse(warn=warn)
        return inverse_transform

    def apply_inverse_transform(self, **kwargs) -> Subject:
        """Apply the inverse of all applied transforms, in reverse order.

        The history of the returned subject is cleared.

        Args:
            **kwargs: Keyword arguments passed on to
                :meth:`~torchio.data.subject.Subject.get_inverse_transform`.
        """
        inverse_transform = self.get_inverse_transform(**kwargs)
        transformed: Subject
        transformed = inverse_transform(self)  # type: ignore[assignment]
        transformed.clear_history()
        return transformed

    def clear_history(self) -> None:
        """Forget all recorded transforms."""
        self.applied_transforms = []

    def check_consistent_attribute(
        self,
        attribute: str,
        relative_tolerance: float = 1e-6,
        absolute_tolerance: float = 1e-6,
        message: str | None = None,
    ) -> None:
        r"""Check for consistency of an attribute across all images.

        Args:
            attribute: Name of the image attribute to check
            relative_tolerance: Relative tolerance for :func:`numpy.allclose()`
            absolute_tolerance: Absolute tolerance for :func:`numpy.allclose()`
            message: Template for the error message. It may contain a single
                ``{}`` placeholder, which is replaced with a pretty-printed
                mapping from image names to the conflicting values. If
                ``None``, a default message mentioning ``attribute`` is used.

        Raises:
            RuntimeError: If the attribute differs between images.

        Example:
            >>> import numpy as np
            >>> import torch
            >>> import torchio as tio
            >>> scalars = torch.randn(1, 512, 512, 100)
            >>> mask = (scalars > 0).to(torch.int16)
            >>> af1 = np.diag([0.8, 0.8, 2.50000000000001, 1])
            >>> af2 = np.diag([0.8, 0.8, 2.49999999999999, 1])  # small difference here (e.g. due to different reader)
            >>> subject = tio.Subject(
            ...   image = tio.ScalarImage(tensor=scalars, affine=af1),
            ...   mask = tio.LabelMap(tensor=mask, affine=af2)
            ... )
            >>> subject.check_consistent_attribute('spacing')  # no error as tolerances are > 0

        .. note:: To check that all values for a specific attribute are close
            between all images in the subject, :func:`numpy.allclose()` is used.
            This function returns ``True`` if
            :math:`|a_i - b_i| \leq t_{abs} + t_{rel} * |b_i|`, where
            :math:`a_i` and :math:`b_i` are the :math:`i`-th element of the same
            attribute of two images being compared,
            :math:`t_{abs}` is the ``absolute_tolerance`` and
            :math:`t_{rel}` is the ``relative_tolerance``.
        """
        if message is None:
            message = (
                f'More than one value for "{attribute}" found in subject images:\n{{}}'
            )

        names_images = self.get_images_dict(intensity_only=False).items()
        try:
            first_image = None
            first_attribute = None
            # The first image is detected by position, not by a None value,
            # so that an attribute that is legitimately None is still compared.
            for index, (image_name, image) in enumerate(names_images):
                current_attribute = getattr(image, attribute)
                if index == 0:
                    first_image = image_name
                    first_attribute = current_attribute
                    continue
                all_close = np.allclose(
                    current_attribute,
                    first_attribute,
                    rtol=relative_tolerance,
                    atol=absolute_tolerance,
                )
                if not all_close:
                    conflicting = {
                        first_image: first_attribute,
                        image_name: current_attribute,
                    }
                    raise RuntimeError(message.format(pprint.pformat(conflicting)))
        except TypeError:
            # fallback for non-numeric values, which np.allclose cannot compare
            values_dict = {
                image_name: getattr(image, attribute)
                for image_name, image in names_images
            }
            if len(set(values_dict.values())) > 1:
                formatted = message.format(pprint.pformat(values_dict))
                raise RuntimeError(formatted) from None

    def check_consistent_spatial_shape(self) -> None:
        """Raise ``RuntimeError`` if images have different spatial shapes."""
        self.check_consistent_attribute('spatial_shape')

    def check_consistent_orientation(self) -> None:
        """Raise ``RuntimeError`` if images have different orientations."""
        self.check_consistent_attribute('orientation')

    def check_consistent_affine(self) -> None:
        """Raise ``RuntimeError`` if images have different affine matrices."""
        self.check_consistent_attribute('affine')

    def check_consistent_space(self) -> None:
        """Check that spacing, direction, origin and spatial shape all match.

        Raises:
            RuntimeError: With a hint on how to fix the problem, chained to the
                error describing the first inconsistent attribute found.
        """
        try:
            self.check_consistent_attribute('spacing')
            self.check_consistent_attribute('direction')
            self.check_consistent_attribute('origin')
            self.check_consistent_spatial_shape()
        except RuntimeError as e:
            message = (
                'As described above, some images in the subject are not in the'
                ' same space. You probably can use the transforms ToCanonical'
                ' and Resample to fix this, as explained at'
                ' https://github.com/TorchIO-project/torchio/issues/647#issuecomment-913025695'
            )
            raise RuntimeError(message) from e

    def get_images_names(self) -> list[str]:
        """Return the names of all images, intensity or not."""
        return list(self.get_images_dict(intensity_only=False).keys())

    def get_images_dict(
        self,
        intensity_only=True,
        include: Sequence[str] | None = None,
        exclude: Sequence[str] | None = None,
    ) -> dict[str, Image]:
        """Return a mapping from names to the images stored in the subject.

        Values that are not :class:`Image` instances are ignored.

        Args:
            intensity_only: If ``True``, keep only images whose ``TYPE`` entry
                is ``INTENSITY``.
            include: If given, keep only images whose name is in it.
            exclude: If given, drop images whose name is in it.
        """
        images = {}
        for image_name, image in self.items():
            if not isinstance(image, Image):
                continue
            if intensity_only and image[TYPE] != INTENSITY:
                continue
            if include is not None and image_name not in include:
                continue
            if exclude is not None and image_name in exclude:
                continue
            images[image_name] = image
        return images

    def get_images(
        self,
        intensity_only=True,
        include: Sequence[str] | None = None,
        exclude: Sequence[str] | None = None,
    ) -> list[Image]:
        """Return the images selected by :meth:`get_images_dict` as a list."""
        images_dict = self.get_images_dict(
            intensity_only=intensity_only,
            include=include,
            exclude=exclude,
        )
        return list(images_dict.values())

    def get_first_image(self) -> Image:
        """Return the first image, in insertion order.

        Raises:
            IndexError: If the subject contains no images.
        """
        return self.get_images(intensity_only=False)[0]

    def add_transform(
        self,
        transform: Transform,
        parameters_dict: dict,
    ) -> None:
        """Record ``(transform.name, parameters_dict)`` in the history."""
        self.applied_transforms.append((transform.name, parameters_dict))

    def load(self) -> None:
        """Load images in subject on RAM."""
        for image in self.get_images(intensity_only=False):
            image.load()

    def unload(self) -> None:
        """Unload images in subject."""
        for image in self.get_images(intensity_only=False):
            image.unload()

    def update_attributes(self) -> None:
        """Expose every item (not only images) as an instance attribute.

        Existing instance attributes with the same names are overwritten.
        """
        # This allows to get images using attribute notation, e.g. subject.t1
        self.__dict__.update(self)

    @staticmethod
    def _check_image_name(image_name):
        """Return ``image_name`` unchanged, or raise ``ValueError`` if not a str."""
        if not isinstance(image_name, str):
            message = (
                f'The image name must be a string, but it has type "{type(image_name)}"'
            )
            raise ValueError(message)
        return image_name

    def add_image(self, image: Image, image_name: str) -> None:
        """Add an image to the subject instance.

        Raises:
            ValueError: If ``image`` is not an :class:`Image` or
                ``image_name`` is not a string.
        """
        if not isinstance(image, Image):
            message = (
                'Image must be an instance of torchio.Image,'
                f' but its type is "{type(image)}"'
            )
            raise ValueError(message)
        self._check_image_name(image_name)
        self[image_name] = image
        self.update_attributes()

    def remove_image(self, image_name: str) -> None:
        """Remove an image from the subject instance.

        Both the dictionary entry and its attribute alias (if any) are removed.

        Raises:
            ValueError: If ``image_name`` is not a string.
            KeyError: If ``image_name`` is not a key of the subject.
        """
        self._check_image_name(image_name)
        del self[image_name]
        # The alias may be missing if the item was set with plain dictionary
        # assignment, without update_attributes(); don't fail half-way.
        self.__dict__.pop(image_name, None)

    def plot(self, **kwargs) -> None:
        """Plot images using matplotlib.

        Args:
            **kwargs: Keyword arguments that will be passed on to
                :meth:`~torchio.Image.plot`.
        """
        from ..visualization import plot_subject  # avoid circular import

        plot_subject(self, **kwargs)