Passed
Pull Request — master (#226)
by Fernando
01:31
created

torchio.data.subject.Subject.__repr__()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 4
nop 1
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
import copy
2
import json
3
import pprint
4
from collections import OrderedDict
5
from typing import Any, Dict, List, Tuple
6
from ..torchio import TYPE, INTENSITY
7
from ..utils import get_transform_class
8
from .image import Image
9
10
11
class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Every item is stored as a dictionary entry and, at construction time,
    also mirrored as an instance attribute (e.g. ``subject.t1``). A
    ``'history'`` entry holding a JSON string of applied transforms is always
    added.

    Args:
        *args: If provided, a dictionary of items.
        **kwargs: Items that will be added to the subject sample.

    Raises:
        ValueError: If more than one positional argument is given, if the
            positional argument is not a ``dict``, or if no item is an
            :class:`Image`.

    Example:

        >>> import torchio
        >>> from torchio import Image, Subject
        >>> # One way:
        >>> subject = Subject(
        ...     one_image=Image('path_to_image.nii.gz', type=torchio.INTENSITY),
        ...     a_segmentation=Image('path_to_seg.nii.gz', type=torchio.LABEL),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping before, or have spaces in the keys:
        >>> subject_dict = {
        ...     'one image': Image('path_to_image.nii.gz', type=torchio.INTENSITY),
        ...     'a segmentation': Image('path_to_seg.nii.gz', type=torchio.LABEL),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> Subject(subject_dict)

    """

    def __init__(self, *args, **kwargs: Dict[str, Any]):
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                kwargs.update(args[0])
            else:
                message = (
                    'Only one dictionary as positional argument is allowed')
                raise ValueError(message)
        super().__init__(**kwargs)
        # (name, Image) pairs captured at construction time only; later
        # item assignments are not reflected here.
        self.images = [
            (k, v) for (k, v) in self.items()
            if isinstance(v, Image)
        ]
        self._parse_images(self.images)
        # The history always starts empty here; ``crop`` restores it
        # explicitly for subjects derived from another one.
        self['history'] = '{}'
        # Mirror every key as an attribute so that e.g. ``subject.t1`` works.
        # NOTE: a key named like an existing attribute (e.g. 'images') will
        # shadow it, since this runs after ``self.images`` is assigned.
        self.__dict__.update(self)

    def __repr__(self):
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {len(self.images)})'
        )
        return string

    @staticmethod
    def _parse_images(images: List[Tuple[str, Image]]) -> None:
        """Validate the (name, Image) pairs found at construction time.

        Raises:
            ValueError: If ``images`` is empty.
        """
        # Check that it's not empty
        if not images:
            raise ValueError('A subject without images cannot be created')

    def _set_history(self, string: str) -> None:
        """Store ``string`` as the JSON transform history.

        Updates both the ``'history'`` item and the mirrored ``history``
        attribute so the two never disagree.
        """
        self['history'] = string
        self.__dict__['history'] = string

    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.

        Raises:
            ValueError: If the images have different shapes.
        """
        self.check_consistent_shape()
        image = self.get_images(intensity_only=False)[0]
        return image.shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        This is :attr:`shape` without its first dimension. Consistency of
        shapes across images in the subject is checked first.
        """
        return self.shape[1:]

    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of shapes across images in the subject is checked first.

        Raises:
            ValueError: If the images have different shapes.
        """
        self.check_consistent_shape()
        image = self.get_images(intensity_only=False)[0]
        return image.spacing

    def get_images_dict(self, intensity_only=True):
        """Return a ``{name: Image}`` dict of the images in this subject.

        Args:
            intensity_only: If ``True``, only images whose ``TYPE`` entry
                equals ``INTENSITY`` are returned.
        """
        images = {}
        for image_name, image in self.items():
            if not isinstance(image, Image):
                continue
            if intensity_only and image[TYPE] != INTENSITY:
                continue
            images[image_name] = image
        return images

    def get_images(self, intensity_only=True):
        """Return a list of the images in this subject, in item order.

        Args:
            intensity_only: Same meaning as in :meth:`get_images_dict`.
        """
        images_dict = self.get_images_dict(intensity_only=intensity_only)
        return list(images_dict.values())

    def check_consistent_shape(self) -> None:
        """Check that all images in the subject have the same shape.

        Raises:
            ValueError: If more than one distinct shape is found; the message
                lists the shape of every image.
        """
        shapes_dict = {
            image_name: image.shape
            for image_name, image
            in self.get_images_dict(intensity_only=False).items()
        }
        num_unique_shapes = len(set(shapes_dict.values()))
        if num_unique_shapes > 1:
            message = (
                'Images in sample have inconsistent shapes:'
                f'\n{pprint.pformat(shapes_dict)}'
            )
            raise ValueError(message)

    def add_transform(
            self,
            transform: 'Transform',
            seed: int,
            ) -> None:
        """Record that ``transform`` was applied with ``seed``.

        The history is a JSON object mapping transform class names to seeds,
        in insertion order. Applying the same transform class twice
        overwrites its previous seed.
        """
        dictionary = json.loads(self['history'], object_pairs_hook=OrderedDict)
        dictionary[transform.__class__.__name__] = seed
        self._set_history(json.dumps(dictionary))

    def get_applied_transforms(self):
        """Return ``(transform_class, seed)`` pairs in the order recorded."""
        dictionary = json.loads(self['history'], object_pairs_hook=OrderedDict)
        classes_and_seeds = [
            (get_transform_class(name), seed)
            for (name, seed) in dictionary.items()
        ]
        return classes_and_seeds

    def load(self):
        """Call ``load()`` on every image in the subject."""
        for image in self.get_images(intensity_only=False):
            image.load()

    def crop(self, index_ini, index_fin):
        """Return a new subject with every image cropped.

        Images are cropped with ``Image.crop(index_ini, index_fin)``; every
        other item is deep-copied. The transform history of this subject is
        carried over to the new one.
        """
        result_dict = {}
        for key, value in self.items():
            if isinstance(value, Image):
                # Image.crop returns a new image, so no deepcopy is needed
                value = value.crop(index_ini, index_fin)
            else:
                value = copy.deepcopy(value)
            result_dict[key] = value
        cropped = Subject(result_dict)
        # Subject.__init__ resets 'history' to '{}', which would discard the
        # transforms already applied to this subject; restore it explicitly.
        cropped._set_history(self['history'])
        return cropped
159