Passed
Pull Request — master (#417)
by Fernando
01:19
created

torchio.data.subject.Subject._parse_images()   A

Complexity

Conditions 2

Size

Total Lines 5
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 4
nop 1
dl 0
loc 5
rs 10
c 0
b 0
f 0
1
import copy
2
import pprint
3
from typing import Any, Dict, List, Tuple, Optional, Sequence
4
5
import numpy as np
6
7
from ..constants import TYPE, INTENSITY
8
from .image import Image
9
from ..utils import get_subclasses
10
11
12
class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Args:
        *args: If provided, a dictionary of items.
        **kwargs: Items that will be added to the subject sample.

    Raises:
        ValueError: If more than one positional argument is given, if the
            positional argument is not a ``dict``, or if the resulting
            subject contains no :class:`Image` instances.

    Example:

        >>> import torchio as tio
        >>> # One way:
        >>> subject = tio.Subject(
        ...     one_image=tio.ScalarImage('path_to_image.nii.gz'),
        ...     a_segmentation=tio.LabelMap('path_to_seg.nii.gz'),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping before, or have spaces in the keys:
        >>> subject_dict = {
        ...     'one image': tio.ScalarImage('path_to_image.nii.gz'),
        ...     'a segmentation': tio.LabelMap('path_to_seg.nii.gz'),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> subject = tio.Subject(subject_dict)

    """

    def __init__(self, *args, **kwargs: Any):
        # A single positional dict is merged into the keyword items; on a key
        # clash the positional dict's value wins (it is applied last).
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                kwargs.update(args[0])
            else:
                message = (
                    'Only one dictionary as positional argument is allowed')
                raise ValueError(message)
        super().__init__(**kwargs)
        self._parse_images(self.get_images(intensity_only=False))
        self.update_attributes()  # this allows me to do e.g. subject.t1
        self.applied_transforms = []

    def __repr__(self):
        num_images = len(self.get_images(intensity_only=False))
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {num_images})'
        )
        return string

    def __copy__(self):
        """Return a new :class:`Subject` with copied items.

        Images are copied with ``copy.copy``; every other value is copied with
        ``copy.deepcopy``. The list of applied transforms is copied so the new
        subject's history can be modified independently.
        """
        result_dict = {}
        for key, value in self.items():
            if isinstance(value, Image):
                value = copy.copy(value)
            else:
                value = copy.deepcopy(value)
            result_dict[key] = value
        new = Subject(result_dict)
        new.applied_transforms = self.applied_transforms[:]
        return new

    def __len__(self):
        """Return the number of images in the subject (not the number of keys)."""
        return len(self.get_images(intensity_only=False))

    @staticmethod
    def _parse_images(images: List[Image]) -> None:
        """Raise ``ValueError`` if ``images`` is empty."""
        # Check that it's not empty
        if not images:
            raise ValueError('A subject without images cannot be created')

    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.
        """
        self.check_consistent_attribute('shape')
        return self.get_first_image().shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        Consistency of spatial shapes across images in the subject is checked
        first.
        """
        self.check_consistent_spatial_shape()
        return self.get_first_image().spatial_shape

    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of spacings across images in the subject is checked first.
        """
        self.check_consistent_attribute('spacing')
        return self.get_first_image().spacing

    @property
    def history(self):
        """Return the applied transforms as instances.

        Equivalent to :meth:`get_applied_transforms` with default arguments.
        """
        # Kept for backwards compatibility
        return self.get_applied_transforms()

    def get_applied_transforms(self, ignore_intensity: bool = False):
        """Re-instantiate the transforms recorded in ``applied_transforms``.

        Each record is a ``(transform_name, arguments)`` pair (see
        :meth:`add_transform`). The name is resolved against all subclasses of
        ``Transform`` and the class is called with ``**arguments``.

        Args:
            ignore_intensity: If ``True``, skip instances of
                ``IntensityTransform``.

        Raises:
            KeyError: If a recorded name does not match any ``Transform``
                subclass.
        """
        # Local imports, presumably to avoid a circular import with the
        # transforms package.
        from ..transforms.transform import Transform
        from ..transforms.intensity_transform import IntensityTransform
        name_to_transform = {
            cls.__name__: cls
            for cls in get_subclasses(Transform)
        }
        transforms_list = []
        for transform_name, arguments in self.applied_transforms:
            transform = name_to_transform[transform_name](**arguments)
            if ignore_intensity and isinstance(transform, IntensityTransform):
                continue
            transforms_list.append(transform)
        return transforms_list

    def get_composed_history(
            self,
            ignore_intensity: bool = False,
            ) -> 'Transform':
        """Return a ``Compose`` of the applied transforms, in recorded order."""
        from ..transforms.augmentation.composition import Compose
        transforms = self.get_applied_transforms(
            ignore_intensity=ignore_intensity)
        return Compose(transforms)

    def get_inverse_transform(
            self,
            warn: bool = True,
            ignore_intensity: bool = True,
            ) -> 'Transform':
        """Return the inverse of the composed transform history.

        Args:
            warn: Forwarded to the composed transform's ``inverse`` method.
            ignore_intensity: If ``True`` (default), intensity transforms are
                left out of the history before inverting.
        """
        history_transform = self.get_composed_history(
            ignore_intensity=ignore_intensity)
        inverse_transform = history_transform.inverse(warn=warn)
        return inverse_transform

    def apply_inverse_transform(
            self,
            warn: bool = True,
            ignore_intensity: bool = True,
            ) -> 'Subject':
        """Apply the inverse transform to this subject and return the result.

        The returned subject's transform history is cleared.
        """
        inverse_transform = self.get_inverse_transform(
            warn=warn,
            ignore_intensity=ignore_intensity,
        )
        transformed = inverse_transform(self)
        transformed.clear_history()
        return transformed

    def clear_history(self) -> None:
        """Forget all recorded transforms."""
        self.applied_transforms = []

    def check_consistent_attribute(self, attribute: str) -> None:
        """Check that all images share the same value of ``attribute``.

        Values are collected into a ``set``, so they must be hashable.

        Raises:
            RuntimeError: If more than one distinct value is found.
        """
        values_dict = {}
        iterable = self.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            values_dict[image_name] = getattr(image, attribute)
        num_unique_values = len(set(values_dict.values()))
        if num_unique_values > 1:
            message = (
                f'More than one {attribute} found in subject images:'
                f'\n{pprint.pformat(values_dict)}'
            )
            raise RuntimeError(message)

    def check_consistent_spatial_shape(self) -> None:
        """Check that all images have the same spatial shape."""
        self.check_consistent_attribute('spatial_shape')

    def check_consistent_orientation(self) -> None:
        """Check that all images have the same orientation."""
        self.check_consistent_attribute('orientation')

    def check_consistent_affine(self):
        """Check that all affines match the first image's affine.

        Comparison uses ``np.allclose`` with ``rtol=1e-6`` and ``atol=1e-6``.

        Raises:
            RuntimeError: If any image's affine differs from the first one.
        """
        # https://github.com/fepegar/torchio/issues/354
        affine = None
        first_image = None
        iterable = self.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            if affine is None:
                affine = image.affine
                first_image = image_name
            elif not np.allclose(affine, image.affine, rtol=1e-6, atol=1e-6):
                message = (
                    f'Images "{first_image}" and "{image_name}" do not occupy'
                    ' the same physical space.'
                    f'\nAffine of "{first_image}":'
                    f'\n{pprint.pformat(affine)}'
                    f'\nAffine of "{image_name}":'
                    f'\n{pprint.pformat(image.affine)}'
                )
                raise RuntimeError(message)

    def check_consistent_space(self):
        """Check that all images share spatial shape and affine."""
        self.check_consistent_spatial_shape()
        self.check_consistent_affine()

    def get_images_dict(
            self,
            intensity_only=True,
            include: Optional[Sequence[str]] = None,
            exclude: Optional[Sequence[str]] = None,
            ) -> Dict[str, Image]:
        """Return a mapping from key to image for the selected images.

        Args:
            intensity_only: If ``True``, keep only images whose ``TYPE``
                entry equals ``INTENSITY``.
            include: If given, keep only images whose key is in this sequence.
            exclude: If given, drop images whose key is in this sequence.
        """
        images = {}
        for image_name, image in self.items():
            if not isinstance(image, Image):
                continue
            if intensity_only and image[TYPE] != INTENSITY:
                continue
            if include is not None and image_name not in include:
                continue
            if exclude is not None and image_name in exclude:
                continue
            images[image_name] = image
        return images

    def get_images(
            self,
            intensity_only=True,
            include: Optional[Sequence[str]] = None,
            exclude: Optional[Sequence[str]] = None,
            ) -> List[Image]:
        """Return the images selected by :meth:`get_images_dict` as a list."""
        images_dict = self.get_images_dict(
            intensity_only=intensity_only,
            include=include,
            exclude=exclude,
        )
        return list(images_dict.values())

    def get_first_image(self) -> Image:
        """Return the first image in insertion order (of any type)."""
        return self.get_images(intensity_only=False)[0]

    # flake8: noqa: F821
    def add_transform(
            self,
            transform: 'Transform',
            parameters_dict: dict,
            ) -> None:
        """Record ``(transform.name, parameters_dict)`` in the history."""
        self.applied_transforms.append((transform.name, parameters_dict))

    def load(self) -> None:
        """Load images in subject."""
        for image in self.get_images(intensity_only=False):
            image.load()

    def update_attributes(self) -> None:
        """Mirror every key of the subject as an instance attribute."""
        # This allows to get images using attribute notation, e.g. subject.t1
        self.__dict__.update(self)

    def add_image(self, image: Image, image_name: str) -> None:
        """Add an image."""
        self[image_name] = image
        self.update_attributes()

    def remove_image(self, image_name: str) -> None:
        """Remove an image.

        The attribute mirrored by :meth:`update_attributes` is removed too, so
        that ``subject.<image_name>`` no longer returns the removed image.

        Raises:
            KeyError: If ``image_name`` is not a key of the subject.
        """
        del self[image_name]
        # The attribute may be absent if the item was set via item assignment
        # without a subsequent update_attributes() call.
        self.__dict__.pop(image_name, None)

    def plot(self, **kwargs) -> None:
        """Plot images."""
        from ..visualization import plot_subject  # avoid circular import
        plot_subject(self, **kwargs)
271