Passed
Push — master ( 68316f...4d5c4f )
by Fernando
01:16
created

torchio.data.subject.Subject.get_images()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import copy
2
import pprint
3
from typing import Any, Dict, List, Tuple
4
5
import numpy as np
6
7
from ..torchio import TYPE, INTENSITY
8
from .image import Image
9
10
11
class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Every item is also exposed as an instance attribute (see
    :meth:`update_attributes`), so an image stored under the key ``'t1'`` can
    be accessed as ``subject.t1``.

    Args:
        *args: If provided, a dictionary of items.
        **kwargs: Items that will be added to the subject sample.

    Raises:
        ValueError: If more than one positional argument, or a positional
            argument that is not a ``dict``, is given, or if none of the
            items is an :class:`Image`.

    Example:

        >>> import torchio as tio
        >>> # One way:
        >>> subject = tio.Subject(
        ...     one_image=tio.ScalarImage('path_to_image.nii.gz'),
        ...     a_segmentation=tio.LabelMap('path_to_seg.nii.gz'),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping before, or have spaces in the keys:
        >>> subject_dict = {
        ...     'one image': tio.ScalarImage('path_to_image.nii.gz'),
        ...     'a segmentation': tio.LabelMap('path_to_seg.nii.gz'),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> subject = tio.Subject(subject_dict)

    """

    def __init__(self, *args, **kwargs: Any):
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                # Items of the positional dict take precedence over kwargs
                kwargs.update(args[0])
            else:
                message = (
                    'Only one dictionary as positional argument is allowed')
                raise ValueError(message)
        super().__init__(**kwargs)
        self._parse_images(self.get_images(intensity_only=False))
        self.update_attributes()  # this allows me to do e.g. subject.t1
        # Set after update_attributes() so it is always a fresh list, even if
        # an item happens to share this name
        self.applied_transforms = []

    def __repr__(self):
        num_images = len(self.get_images(intensity_only=False))
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {num_images})'
        )
        return string

    def __copy__(self):
        """Return a new subject with shallow-copied images.

        Images are copied with :func:`copy.copy`; every other value is
        deep-copied. The transform history list is copied too, so the new
        subject's history can be modified independently.
        """
        result_dict = {}
        for key, value in self.items():
            if isinstance(value, Image):
                value = copy.copy(value)
            else:
                value = copy.deepcopy(value)
            result_dict[key] = value
        new = Subject(result_dict)
        new.applied_transforms = self.applied_transforms[:]
        return new

    def __len__(self):
        # Number of images, not number of items: non-image values are ignored
        return len(self.get_images(intensity_only=False))

    @staticmethod
    def _parse_images(images: List[Image]) -> None:
        """Raise ``ValueError`` if ``images`` is empty."""
        # Check that it's not empty
        if not images:
            raise ValueError('A subject without images cannot be created')

    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.
        """
        self.check_consistent_attribute('shape')
        return self.get_first_image().shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        Consistency of spatial shapes across images in the subject is checked
        first.
        """
        self.check_consistent_spatial_shape()
        return self.get_first_image().spatial_shape

    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of spacings across images in the subject is checked first.
        """
        self.check_consistent_attribute('spacing')
        return self.get_first_image().spacing

    @property
    def history(self):
        """List of transform instances rebuilt from ``applied_transforms``.

        Each recorded ``(transform_name, arguments)`` pair is turned into an
        instance by looking up ``transform_name`` in the ``transforms``
        package and calling it with ``arguments`` as keyword arguments.
        """
        from .. import transforms
        transforms_list = []
        for transform_name, arguments in self.applied_transforms:
            transform = getattr(transforms, transform_name)(**arguments)
            transforms_list.append(transform)
        return transforms_list

    def get_composed_history(self) -> 'Transform':
        """Return a ``Compose`` transform built from :attr:`history`."""
        from ..transforms.augmentation.composition import Compose
        return Compose(self.history)

    def get_inverse_transform(self) -> 'Transform':
        """Return the inverse of the composed transform history."""
        return self.get_composed_history().inverse()

    def apply_inverse_transform(self) -> 'Subject':
        """Apply the inverse of the transform history to this subject.

        Returns:
            Whatever the inverse transform returns when called on this
            subject.
        """
        return self.get_inverse_transform()(self)

    def clear_history(self) -> None:
        """Forget all recorded transforms."""
        self.applied_transforms = []

    def check_consistent_attribute(self, attribute: str) -> None:
        """Check that all images in the subject share the same ``attribute``.

        Args:
            attribute: Name of the image attribute to compare, e.g.
                ``'spacing'``. Its values are compared through a ``set``, so
                they must be hashable.

        Raises:
            RuntimeError: If more than one distinct value is found. The
                message lists the value found for each image.
        """
        values_dict = {}
        iterable = self.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            values_dict[image_name] = getattr(image, attribute)
        num_unique_values = len(set(values_dict.values()))
        if num_unique_values > 1:
            message = (
                f'More than one {attribute} found in subject images:'
                f'\n{pprint.pformat(values_dict)}'
            )
            raise RuntimeError(message)

    def check_consistent_spatial_shape(self) -> None:
        """Check that all images share the same ``spatial_shape``."""
        self.check_consistent_attribute('spatial_shape')

    def check_consistent_orientation(self) -> None:
        """Check that all images share the same ``orientation``."""
        self.check_consistent_attribute('orientation')

    def check_consistent_affine(self):
        """Check that all images have approximately the same affine matrix.

        Every image's affine is compared with the first image's affine using
        ``np.allclose`` with ``rtol=1e-7`` (and numpy's default ``atol``).

        Raises:
            RuntimeError: If an image's affine is not close to the first one.
        """
        # https://github.com/fepegar/torchio/issues/354
        affine = None
        first_image = None  # name of the reference (first) image
        iterable = self.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            if affine is None:
                affine = image.affine
                first_image = image_name
            elif not np.allclose(affine, image.affine, rtol=1e-7):
                message = (
                    f'Images "{first_image}" and "{image_name}" do not occupy'
                    ' the same physical space.'
                    f'\nAffine of "{first_image}":'
                    f'\n{pprint.pformat(affine)}'
                    f'\nAffine of "{image_name}":'
                    f'\n{pprint.pformat(image.affine)}'
                )
                raise RuntimeError(message)

    def check_consistent_space(self):
        """Check that all images share spatial shape and affine."""
        self.check_consistent_spatial_shape()
        self.check_consistent_affine()

    def get_images_dict(self, intensity_only=True) -> Dict[str, Image]:
        """Return the subject's images, keyed by their item name.

        Args:
            intensity_only: If ``True``, keep only images whose ``TYPE``
                entry equals ``INTENSITY``.
        """
        images = {}
        for image_name, image in self.items():
            if not isinstance(image, Image):
                continue
            if intensity_only and not image[TYPE] == INTENSITY:
                continue
            images[image_name] = image
        return images

    def get_images(self, intensity_only=True) -> List[Image]:
        """Return the images as a list, in item order.

        Args:
            intensity_only: See :meth:`get_images_dict`.
        """
        images_dict = self.get_images_dict(intensity_only=intensity_only)
        return list(images_dict.values())

    def get_first_image(self) -> Image:
        """Return the first image of the subject, regardless of its type."""
        return self.get_images(intensity_only=False)[0]

    # flake8: noqa: F821
    def add_transform(
            self,
            transform: 'Transform',
            parameters_dict: dict,
            ) -> None:
        """Record a transform as ``(transform.name, parameters_dict)``.

        The record is appended to ``applied_transforms`` and later used by
        :attr:`history`.
        """
        self.applied_transforms.append((transform.name, parameters_dict))

    def load(self) -> None:
        """Load images in subject."""
        for image in self.get_images(intensity_only=False):
            image.load()

    def update_attributes(self) -> None:
        """Mirror every item (not only images) as an instance attribute."""
        # This allows to get images using attribute notation, e.g. subject.t1
        self.__dict__.update(self)

    def add_image(self, image: Image, image_name: str) -> None:
        """Add an image and expose it as an attribute, e.g. ``subject.t1``."""
        self[image_name] = image
        self.update_attributes()

    def remove_image(self, image_name: str) -> None:
        """Remove an image.

        The attribute mirroring the item (created by
        :meth:`update_attributes`) is removed as well, so the image is no
        longer reachable through attribute access either.

        Raises:
            KeyError: If ``image_name`` is not a key of the subject.
        """
        image = self[image_name]
        del self[image_name]
        # Only drop the attribute while it still mirrors the removed item, so
        # that a regular instance attribute with the same name (e.g. one set
        # after update_attributes()) is left untouched.
        if self.__dict__.get(image_name) is image:
            del self.__dict__[image_name]

    def plot(self, **kwargs) -> None:
        """Plot images."""
        from ..visualization import plot_subject  # avoid circular import
        plot_subject(self, **kwargs)
223