Passed
Push — master ( 4c09bc...11de88 )
by Fernando
01:29
created

torchio.data.subject.Subject.load()   A

Complexity

Conditions 2

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import copy
2
import pprint
3
from typing import Any, Dict, List, Tuple
4
from ..torchio import TYPE, INTENSITY
5
from .image import Image
6
7
8
class Subject(dict):
9
    """Class to store information about the images corresponding to a subject.
10
11
    Args:
12
        *args: If provided, a dictionary of items.
13
        **kwargs: Items that will be added to the subject sample.
14
15
    Example:
16
17
        >>> import torchio
18
        >>> from torchio import ScalarImage, LabelMap, Subject
19
        >>> # One way:
20
        >>> subject = Subject(
21
        ...     one_image=ScalarImage('path_to_image.nii.gz'),
22
        ...     a_segmentation=LabelMap('path_to_seg.nii.gz'),
23
        ...     age=45,
24
        ...     name='John Doe',
25
        ...     hospital='Hospital Juan Negrín',
26
        ... )
27
        >>> # If you want to create the mapping before, or have spaces in the keys:
28
        >>> subject_dict = {
29
        ...     'one image': ScalarImage('path_to_image.nii.gz'),
30
        ...     'a segmentation': LabelMap('path_to_seg.nii.gz'),
31
        ...     'age': 45,
32
        ...     'name': 'John Doe',
33
        ...     'hospital': 'Hospital Juan Negrín',
34
        ... }
35
        >>> Subject(subject_dict)
36
37
    """
38
39
    def __init__(self, *args, **kwargs: Dict[str, Any]):
        """Build the subject from keyword items and/or one positional dict.

        Args:
            *args: Either nothing, or exactly one dictionary. Its entries are
                merged into ``kwargs``, overriding keyword arguments that
                share a key.
            **kwargs: Items stored in the subject.

        Raises:
            ValueError: If positional arguments are given that are not a
                single dictionary, or if image validation rejects the items.
        """
        if args:
            is_single_dict = len(args) == 1 and isinstance(args[0], dict)
            if not is_single_dict:
                raise ValueError(
                    'Only one dictionary as positional argument is allowed')
            kwargs.update(args[0])
        super().__init__(**kwargs)
        all_images = self.get_images(intensity_only=False)
        self._parse_images(all_images)
        # Mirror the items into instance attributes, e.g. subject.t1
        self.update_attributes()
        self.history = []
51
52
    def __repr__(self):
53
        num_images = len(self.get_images(intensity_only=False))
54
        string = (
55
            f'{self.__class__.__name__}'
56
            f'(Keys: {tuple(self.keys())}; images: {num_images})'
57
        )
58
        return string
59
60
    def __copy__(self):
        """Return a new subject holding copies of this subject's items.

        Images are duplicated with ``copy.copy``; every other value is
        duplicated with ``copy.deepcopy``. The history list is copied so the
        new subject does not share it with this one.
        """
        copied_items = {
            name: (
                copy.copy(item) if isinstance(item, Image)
                else copy.deepcopy(item)
            )
            for name, item in self.items()
        }
        duplicate = Subject(copied_items)
        duplicate.history = self.history[:]
        return duplicate
71
72
    @staticmethod
73
    def _parse_images(images: List[Tuple[str, Image]]) -> None:
74
        # Check that it's not empty
75
        if not images:
76
            raise ValueError('A subject without images cannot be created')
77
78
    @property
    def shape(self):
        """Shape of the first image in the subject.

        All images are first checked to share the same shape; a mismatch
        makes the check raise instead of returning a value.
        """
        self.check_consistent_attribute('shape')
        first_image = self.get_first_image()
        return first_image.shape
86
87
    @property
    def spatial_shape(self):
        """Spatial shape of the first image in the subject.

        All images are first checked to share the same spatial shape; a
        mismatch makes the check raise instead of returning a value.
        """
        self.check_consistent_spatial_shape()
        first_image = self.get_first_image()
        return first_image.spatial_shape
96
97
    @property
    def spacing(self):
        """Spacing of the first image in the subject.

        All images are first checked to share the same spacing; a mismatch
        makes the check raise instead of returning a value.
        """
        self.check_consistent_attribute('spacing')
        first_image = self.get_first_image()
        return first_image.spacing
105
106
    def check_consistent_attribute(self, attribute: str) -> None:
        """Ensure every image in the subject has the same value of an attribute.

        Args:
            attribute: Name of the image attribute to compare across all
                images (intensity and non-intensity alike).

        Raises:
            RuntimeError: If the images hold more than one distinct value.
                The message lists the value found for each image name.
        """
        all_images = self.get_images_dict(intensity_only=False)
        value_per_image = {
            name: getattr(image, attribute)
            for name, image in all_images.items()
        }
        # Deduplicating through a set requires the values to be hashable
        distinct_values = set(value_per_image.values())
        if len(distinct_values) <= 1:
            return
        raise RuntimeError(
            f'More than one {attribute} found in subject images:'
            f'\n{pprint.pformat(value_per_image)}'
        )
118
119
    def check_consistent_spatial_shape(self) -> None:
        """Ensure every image in the subject has the same ``spatial_shape``.

        Delegates to ``check_consistent_attribute``, which raises when the
        images disagree.
        """
        self.check_consistent_attribute('spatial_shape')
121
122
    def get_images_dict(self, intensity_only=True):
        """Return the subject's images as a ``{name: image}`` dictionary.

        Args:
            intensity_only: If ``True``, keep only images whose ``TYPE``
                entry equals ``INTENSITY``. If ``False``, keep every image.

        Returns:
            A new dictionary, in the subject's iteration order, containing
            only the values that are ``Image`` instances.
        """
        return {
            name: item
            for name, item in self.items()
            # The type entry is only read when filtering is requested
            if isinstance(item, Image)
            and (not intensity_only or item[TYPE] == INTENSITY)
        }
131
132
    def get_images(self, intensity_only=True):
        """Return the subject's images as a list.

        Args:
            intensity_only: Forwarded to ``get_images_dict`` to select which
                images are included.

        Returns:
            The values of ``get_images_dict(...)`` in the same order.
        """
        selected = self.get_images_dict(intensity_only=intensity_only)
        return [*selected.values()]
135
136
    def get_first_image(self):
        """Return the first image of the subject, whatever its type."""
        all_images = self.get_images(intensity_only=False)
        return all_images[0]
138
139
    # flake8: noqa: F821
140
    def add_transform(
            self,
            transform: 'Transform',
            parameters_dict: dict,
            ) -> None:
        """Append a ``(transform name, parameters)`` entry to the history.

        Args:
            transform: Object whose ``name`` attribute identifies the
                transform being recorded.
            parameters_dict: Parameters stored alongside the name. The
                dictionary is stored as is, not copied.
        """
        entry = (transform.name, parameters_dict)
        self.history.append(entry)
146
147
    def load(self):
        """Call ``load()`` on every image in the subject."""
        all_images = self.get_images(intensity_only=False)
        for current in all_images:
            current.load()
150
151
    def crop(self, index_ini, index_fin):
        """Return a new subject whose images are cropped.

        Every image is replaced by the result of its own
        ``crop(index_ini, index_fin)``; every other value is deep-copied.

        Args:
            index_ini: Initial index, forwarded unchanged to each image's
                ``crop`` method.
            index_fin: Final index, forwarded unchanged to each image's
                ``crop`` method.

        Returns:
            A new ``Subject`` with the cropped images and its own copy of
            this subject's history.
        """
        result_dict = {}
        for key, value in self.items():
            if isinstance(value, Image):
                value = value.crop(index_ini, index_fin)
            else:
                value = copy.deepcopy(value)
            result_dict[key] = value
        new = Subject(result_dict)
        # Copy the list (as __copy__ does) so that transforms recorded later
        # on the cropped subject are not appended to this subject's history
        new.history = self.history[:]
        return new
163
164
    def update_attributes(self):
        """Mirror the mapping's items into instance attributes.

        After this call, each key currently in the subject is also reachable
        with attribute notation, e.g. ``subject.t1``.
        """
        vars(self).update(self)
167
168
    def add_image(self, image: Image, image_name: str) -> None:
        """Store an image in the subject and expose it as an attribute.

        Args:
            image: Image to add.
            image_name: Key under which the image is stored. An existing
                entry with the same key is replaced.
        """
        self[image_name] = image
        # Re-sync so the new entry is reachable as subject.<image_name>
        self.update_attributes()
171
172
    def remove_image(self, image_name: str) -> None:
        """Remove an image from the subject.

        Both the mapping entry and the attribute of the same name (created
        by ``update_attributes``) are removed, so ``subject.<image_name>``
        no longer resolves to the deleted image.

        Args:
            image_name: Key of the image to remove.

        Raises:
            KeyError: If ``image_name`` is not a key of the subject.
        """
        del self[image_name]
        # The attribute may be missing if the item was never mirrored into
        # the instance attributes, hence pop() with a default
        self.__dict__.pop(image_name, None)
174