Passed
Pull Request — master (#353)
by Fernando
01:07
created

torchio.data.subject.Subject.get_first_image()   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
import copy
2
import pprint
3
from typing import Any, Dict, List, Tuple
4
from ..torchio import TYPE, INTENSITY
5
from .image import Image
6
7
8
class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    A subject is a :class:`dict` whose values may be :class:`Image`
    instances or arbitrary metadata. Every key is also mirrored as an
    instance attribute (see :meth:`update_attributes`), so items can be
    accessed as e.g. ``subject.t1``.

    Args:
        *args: If provided, a dictionary of items. Its entries take
            precedence over keyword arguments that use the same key.
        **kwargs: Items that will be added to the subject sample.

    Raises:
        ValueError: If more than one positional argument, or a positional
            argument that is not a :class:`dict`, is given; or if the
            resulting subject contains no :class:`Image` instances.

    Example:

        >>> from torchio import ScalarImage, LabelMap, Subject
        >>> # One way:
        >>> subject = Subject(
        ...     one_image=ScalarImage('path_to_image.nii.gz'),
        ...     a_segmentation=LabelMap('path_to_seg.nii.gz'),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping before, or have spaces in the keys:
        >>> subject_dict = {
        ...     'one image': ScalarImage('path_to_image.nii.gz'),
        ...     'a segmentation': LabelMap('path_to_seg.nii.gz'),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> Subject(subject_dict)

    """

    def __init__(self, *args, **kwargs: Any):
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                # Entries of the positional mapping override keyword
                # arguments with the same key.
                kwargs.update(args[0])
            else:
                message = (
                    'Only one dictionary as positional argument is allowed')
                raise ValueError(message)
        super().__init__(**kwargs)
        self._parse_images(self.get_images(intensity_only=False))
        self.update_attributes()  # this allows me to do e.g. subject.t1
        # Set after update_attributes so it is never clobbered by a key.
        self.applied_transforms = []

    def __repr__(self):
        num_images = len(self.get_images(intensity_only=False))
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {num_images})'
        )
        return string

    def __copy__(self):
        """Return a new :class:`Subject` with copied values.

        :class:`Image` values are shallow-copied with :func:`copy.copy`;
        every other value is deep-copied. The record of applied transforms
        is copied into a new list.
        """
        result_dict = {}
        for key, value in self.items():
            if isinstance(value, Image):
                value = copy.copy(value)
            else:
                value = copy.deepcopy(value)
            result_dict[key] = value
        new = Subject(result_dict)
        new.applied_transforms = self.applied_transforms[:]
        return new

    def __len__(self):
        """Return the number of :class:`Image` instances in the subject.

        Note that this differs from :class:`dict` semantics: non-image
        entries (e.g. metadata such as ``age``) are not counted.
        """
        return len(self.get_images(intensity_only=False))

    @staticmethod
    def _parse_images(images: List[Image]) -> None:
        """Validate the images found in a subject.

        Raises:
            ValueError: If ``images`` is empty.
        """
        # Check that it's not empty
        if not images:
            raise ValueError('A subject without images cannot be created')

    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.
        """
        self.check_consistent_attribute('shape')
        return self.get_first_image().shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        Consistency of spatial shapes across images in the subject is checked
        first.
        """
        self.check_consistent_spatial_shape()
        return self.get_first_image().spatial_shape

    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of spacings across images in the subject is checked first.
        """
        self.check_consistent_attribute('spacing')
        return self.get_first_image().spacing

    @property
    def history(self):
        """List of transform instances rebuilt from ``applied_transforms``.

        Each recorded ``(transform_name, arguments)`` pair is turned back
        into an instance by looking up ``transform_name`` in the
        ``transforms`` package and calling it with ``arguments`` as keyword
        arguments.
        """
        from .. import transforms  # local import avoids a circular import
        transforms_list = []
        for transform_name, arguments in self.applied_transforms:
            transform = getattr(transforms, transform_name)(**arguments)
            transforms_list.append(transform)
        return transforms_list

    def get_composed_history(self):
        """Return a ``Compose`` wrapping the transforms in :attr:`history`."""
        from ..transforms.augmentation.composition import Compose
        return Compose(self.history)

    def check_consistent_attribute(self, attribute: str) -> None:
        """Check that all images in the subject share the same ``attribute``.

        Values are compared through a :class:`set`, so they must be hashable.

        Args:
            attribute: Name of the image attribute to compare, e.g.
                ``'spacing'``.

        Raises:
            RuntimeError: If more than one distinct value is found.
        """
        images_dict = self.get_images_dict(intensity_only=False)
        values_dict = {
            image_name: getattr(image, attribute)
            for image_name, image in images_dict.items()
        }
        num_unique_values = len(set(values_dict.values()))
        if num_unique_values > 1:
            message = (
                f'More than one {attribute} found in subject images:'
                f'\n{pprint.pformat(values_dict)}'
            )
            raise RuntimeError(message)

    def check_consistent_spatial_shape(self) -> None:
        """Raise ``RuntimeError`` if images differ in spatial shape."""
        self.check_consistent_attribute('spatial_shape')

    def check_consistent_orientation(self) -> None:
        """Raise ``RuntimeError`` if images differ in orientation."""
        self.check_consistent_attribute('orientation')

    def get_images_dict(self, intensity_only=True):
        """Return a new dict mapping keys to the :class:`Image` values.

        Args:
            intensity_only: If ``True``, keep only images whose ``TYPE``
                entry equals ``INTENSITY``.
        """
        images = {}
        for image_name, image in self.items():
            if not isinstance(image, Image):
                continue
            if intensity_only and image[TYPE] != INTENSITY:
                continue
            images[image_name] = image
        return images

    def get_images(self, intensity_only=True):
        """Return the images as a list, in insertion order.

        Args:
            intensity_only: See :meth:`get_images_dict`.
        """
        images_dict = self.get_images_dict(intensity_only=intensity_only)
        return list(images_dict.values())

    def get_first_image(self):
        """Return the first image (of any type) in insertion order.

        Raises:
            IndexError: If the subject currently holds no images.
        """
        return self.get_images(intensity_only=False)[0]

    # flake8: noqa: F821
    def add_transform(
            self,
            transform: 'Transform',
            parameters_dict: dict,
            ) -> None:
        """Record that ``transform`` was applied with ``parameters_dict``.

        Only the transform's ``name`` and the parameters are stored, so that
        :attr:`history` can rebuild it later.
        """
        self.applied_transforms.append((transform.name, parameters_dict))

    def load(self):
        """Load images in subject."""
        for image in self.get_images(intensity_only=False):
            image.load()

    def update_attributes(self):
        """Mirror every item of the subject into the instance ``__dict__``.

        This enables attribute notation such as ``subject.t1``. Existing
        instance attributes with the same name are overwritten. Keys that
        are not valid identifiers are still stored, but they are only
        reachable through :func:`getattr`.
        """
        # This allows to get images using attribute notation, e.g. subject.t1
        self.__dict__.update(self)

    def add_image(self, image: Image, image_name: str) -> None:
        """Add an image and expose it as an attribute."""
        self[image_name] = image
        self.update_attributes()

    def remove_image(self, image_name: str) -> None:
        """Remove an image and its attribute-notation alias.

        Raises:
            KeyError: If ``image_name`` is not a key of the subject.
        """
        removed = self.pop(image_name)
        # update_attributes() copied the item into __dict__; drop that alias
        # too, but only if it still refers to the removed value so that
        # unrelated instance attributes are never deleted.
        if self.__dict__.get(image_name) is removed:
            del self.__dict__[image_name]

    def plot(self, **kwargs) -> None:
        """Plot images."""
        from ..visualization import plot_subject  # avoid circular import
        plot_subject(self, **kwargs)
187