Completed
Push — master (17d4df...a2657b) by Fernando
01:54 (queued 41s)

torchio.data.subject.Subject.is_2d() (grade: A)

Complexity
    Conditions: 1

Size
    Total lines: 2
    Code lines: 2

Duplication
    Lines: 0
    Ratio: 0%

Importance
    Changes: 0

Metric    Value
eloc      2
dl        0
loc       2
rs        10
c         0
b         0
f         0
cc        1
nop       1
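
For context, the method under review returns True only when every image in the subject is two-dimensional. A minimal usage sketch follows; the key name and tensor shape are made up for illustration, and it assumes TorchIO's convention that an image whose last spatial dimension has size 1 counts as 2D:

import torch
import torchio as tio

# A single-channel 256 x 256 image with depth 1 should count as 2D.
slice_image = tio.ScalarImage(tensor=torch.rand(1, 256, 256, 1))
subject = tio.Subject(axial_slice=slice_image)
print(subject.is_2d())  # expected: True, since every image in the subject is 2D

The full source of torchio/data/subject.py follows.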
import copy
import pprint
from typing import Any, Dict, List, Tuple, Optional, Sequence, TYPE_CHECKING

import numpy as np

from ..constants import TYPE, INTENSITY
from .image import Image
from ..utils import get_subclasses

if TYPE_CHECKING:
    from ..transforms import Transform, Compose


class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Args:
        *args: If provided, a dictionary of items.
        **kwargs: Items that will be added to the subject sample.

    Example:

        >>> import torchio as tio
        >>> # One way:
        >>> subject = tio.Subject(
        ...     one_image=tio.ScalarImage('path_to_image.nii.gz'),
        ...     a_segmentation=tio.LabelMap('path_to_seg.nii.gz'),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping before, or have spaces in the keys:
        >>> subject_dict = {
        ...     'one image': tio.ScalarImage('path_to_image.nii.gz'),
        ...     'a segmentation': tio.LabelMap('path_to_seg.nii.gz'),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> subject = tio.Subject(subject_dict)

    """

    def __init__(self, *args, **kwargs: Dict[str, Any]):
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                kwargs.update(args[0])
            else:
                message = (
                    'Only one dictionary as positional argument is allowed')
                raise ValueError(message)
        super().__init__(**kwargs)
        self._parse_images(self.get_images(intensity_only=False))
        self.update_attributes()  # this allows me to do e.g. subject.t1
        self.applied_transforms = []

    def __repr__(self):
        num_images = len(self.get_images(intensity_only=False))
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {num_images})'
        )
        return string

    def __copy__(self):
        result_dict = {}
        for key, value in self.items():
            if isinstance(value, Image):
                value = copy.copy(value)
            else:
                value = copy.deepcopy(value)
            result_dict[key] = value
        new = Subject(result_dict)
        new.applied_transforms = self.applied_transforms[:]
        return new

    def __len__(self):
        return len(self.get_images(intensity_only=False))

    @staticmethod
    def _parse_images(images: List[Tuple[str, Image]]) -> None:
        # Check that it's not empty
        if not images:
            raise ValueError('A subject without images cannot be created')

    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.
        """
        self.check_consistent_attribute('shape')
        return self.get_first_image().shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        Consistency of spatial shapes across images in the subject is checked
        first.
        """
        self.check_consistent_spatial_shape()
        return self.get_first_image().spatial_shape

    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of spacings across images in the subject is checked first.
        """
        self.check_consistent_attribute('spacing')
        return self.get_first_image().spacing

    @property
    def history(self):
        # Kept for backwards compatibility
        return self.get_applied_transforms()

    def is_2d(self):
        return all(i.is_2d() for i in self.get_images(intensity_only=False))

    def get_applied_transforms(
            self,
            ignore_intensity: bool = False,
            image_interpolation: Optional[str] = None,
            ) -> List['Transform']:
        from ..transforms.transform import Transform
        from ..transforms.intensity_transform import IntensityTransform
        name_to_transform = {
            cls.__name__: cls
            for cls in get_subclasses(Transform)
        }
        transforms_list = []
        for transform_name, arguments in self.applied_transforms:
            transform = name_to_transform[transform_name](**arguments)
            if ignore_intensity and isinstance(transform, IntensityTransform):
                continue
            resamples = hasattr(transform, 'image_interpolation')
            if resamples and image_interpolation is not None:
                parsed = transform.parse_interpolation(image_interpolation)
                transform.image_interpolation = parsed
            transforms_list.append(transform)
        return transforms_list

    def get_composed_history(
            self,
            ignore_intensity: bool = False,
            image_interpolation: Optional[str] = None,
            ) -> 'Compose':
        from ..transforms.augmentation.composition import Compose
        transforms = self.get_applied_transforms(
            ignore_intensity=ignore_intensity,
            image_interpolation=image_interpolation,
        )
        return Compose(transforms)

    def get_inverse_transform(
            self,
            warn: bool = True,
            ignore_intensity: bool = True,
            image_interpolation: Optional[str] = None,
            ) -> 'Compose':
        """Get a reversed list of the inverses of the applied transforms.

        Args:
            warn: Issue a warning if some transforms are not invertible.
            ignore_intensity: If ``True``, all instances of
                :class:`~torchio.transforms.intensity_transform.IntensityTransform`
                will be ignored.
            image_interpolation: Modify interpolation for scalar images inside
                transforms that perform resampling.
        """
        history_transform = self.get_composed_history(
            ignore_intensity=ignore_intensity,
            image_interpolation=image_interpolation,
        )
        inverse_transform = history_transform.inverse(warn=warn)
        return inverse_transform

    def apply_inverse_transform(self, **kwargs) -> 'Subject':
        """Try to apply the inverse of all applied transforms, in reverse order.

        Args:
            **kwargs: Keyword arguments passed on to
                :meth:`~torchio.data.subject.Subject.get_inverse_transform`.
        """
        inverse_transform = self.get_inverse_transform(**kwargs)
        transformed = inverse_transform(self)
        transformed.clear_history()
        return transformed

    def clear_history(self) -> None:
        self.applied_transforms = []

    def check_consistent_attribute(self, attribute: str) -> None:
        values_dict = {}
        iterable = self.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            values_dict[image_name] = getattr(image, attribute)
        num_unique_values = len(set(values_dict.values()))
        if num_unique_values > 1:
            message = (
                f'More than one {attribute} found in subject images:'
                f'\n{pprint.pformat(values_dict)}'
            )
            raise RuntimeError(message)

    def check_consistent_spatial_shape(self) -> None:
        self.check_consistent_attribute('spatial_shape')

    def check_consistent_orientation(self) -> None:
        self.check_consistent_attribute('orientation')

    def check_consistent_affine(self):
        # https://github.com/fepegar/torchio/issues/354
        affine = None
        first_image = None
        iterable = self.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            if affine is None:
                affine = image.affine
                first_image = image_name
            elif not np.allclose(affine, image.affine, rtol=1e-6, atol=1e-6):
                message = (
                    f'Images "{first_image}" and "{image_name}" do not occupy'
                    ' the same physical space.'
                    f'\nAffine of "{first_image}":'
                    f'\n{pprint.pformat(affine)}'
                    f'\nAffine of "{image_name}":'
                    f'\n{pprint.pformat(image.affine)}'
                )
                raise RuntimeError(message)

    def check_consistent_space(self):
        self.check_consistent_spatial_shape()
        self.check_consistent_affine()

    def get_images_dict(
            self,
            intensity_only=True,
            include: Optional[Sequence[str]] = None,
            exclude: Optional[Sequence[str]] = None,
            ) -> Dict[str, Image]:
        images = {}
        for image_name, image in self.items():
            if not isinstance(image, Image):
                continue
            if intensity_only and image[TYPE] != INTENSITY:
                continue
            if include is not None and image_name not in include:
                continue
            if exclude is not None and image_name in exclude:
                continue
            images[image_name] = image
        return images

    def get_images(
            self,
            intensity_only=True,
            include: Optional[Sequence[str]] = None,
            exclude: Optional[Sequence[str]] = None,
            ) -> List[Image]:
        images_dict = self.get_images_dict(
            intensity_only=intensity_only,
            include=include,
            exclude=exclude,
        )
        return list(images_dict.values())

    def get_first_image(self) -> Image:
        return self.get_images(intensity_only=False)[0]

    # flake8: noqa: F821
    def add_transform(
            self,
            transform: 'Transform',
            parameters_dict: dict,
            ) -> None:
        self.applied_transforms.append((transform.name, parameters_dict))

    def load(self) -> None:
        """Load images in subject into RAM."""
        for image in self.get_images(intensity_only=False):
            image.load()

    def update_attributes(self) -> None:
        # This allows getting images using attribute notation, e.g. subject.t1
        self.__dict__.update(self)

    def add_image(self, image: Image, image_name: str) -> None:
        """Add an image."""
        self[image_name] = image
        self.update_attributes()

    def remove_image(self, image_name: str) -> None:
        """Remove an image."""
        del self[image_name]
        delattr(self, image_name)

    def plot(self, **kwargs) -> None:
        """Plot images using matplotlib.

        Args:
            **kwargs: Keyword arguments that will be passed on to
                :class:`~torchio.data.image.Image`.
        """
        from ..visualization import plot_subject  # avoid circular import
        plot_subject(self, **kwargs)
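
The transform-history methods above (add_transform, get_applied_transforms, get_composed_history, get_inverse_transform, apply_inverse_transform) are easiest to see as a round trip. A minimal sketch, assuming an invertible spatial transform such as tio.RandomAffine; the image path is a placeholder:

import torchio as tio

subject = tio.Subject(t1=tio.ScalarImage('t1.nii.gz'))  # placeholder path
transform = tio.RandomAffine()    # spatial transform, recorded when applied
transformed = transform(subject)  # applied parameters stored in applied_transforms

print(transformed.history)        # Transform instances rebuilt from the stored history

# Undo the recorded transforms; intensity transforms are skipped by default
# (ignore_intensity=True) and warn=False silences non-invertibility warnings.
restored = transformed.apply_inverse_transform(warn=False)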