Passed
Pull Request — master (#226)
by Fernando
01:12
created

Subject.get_applied_transforms()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import copy
2
import json
3
import pprint
4
from collections import OrderedDict
5
from typing import Any, Dict, List, Tuple
6
from ..torchio import TYPE, INTENSITY
7
from .image import Image
8
9
10
class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Items are stored as regular dictionary entries. The entries present when
    the instance is created are additionally exposed as attributes (e.g.
    ``subject.t1``); entries added afterwards are only reachable with the
    dictionary syntax.

    The transforms recorded with :meth:`add_transform` are kept as a JSON
    string under the ``'history'`` key.

    Args:
        *args: If provided, a dictionary of items.
        **kwargs: Items that will be added to the subject sample.

    Raises:
        ValueError: If the positional arguments are anything other than one
            single dictionary, or if none of the items is an image.

    Example:

        >>> import torchio
        >>> from torchio import Image, Subject
        >>> # One way:
        >>> subject = Subject(
        ...     one_image=Image('path_to_image.nii.gz', type=torchio.INTENSITY),
        ...     a_segmentation=Image('path_to_seg.nii.gz', type=torchio.LABEL),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping before, or have spaces in the keys:
        >>> subject_dict = {
        ...     'one image': Image('path_to_image.nii.gz', type=torchio.INTENSITY),
        ...     'a segmentation': Image('path_to_seg.nii.gz', type=torchio.LABEL),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> Subject(subject_dict)

    """

    def __init__(self, *args, **kwargs: Any):
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                # Entries of the positional dictionary take precedence over
                # keyword arguments with the same name
                kwargs.update(args[0])
            else:
                message = (
                    'Only one dictionary as positional argument is allowed')
                raise ValueError(message)
        super().__init__(**kwargs)
        self.images = [
            (k, v) for (k, v) in self.items()
            if isinstance(v, Image)
        ]
        self._parse_images(self.images)
        # A new instance always starts with an empty history
        self['history'] = '{}'
        self.__dict__.update(self)  # this allows me to do e.g. subject.t1

    def __repr__(self):
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {len(self.images)})'
        )
        return string

    @staticmethod
    def _parse_images(images: List[Tuple[str, 'Image']]) -> None:
        """Raise ``ValueError`` if the list of ``(name, image)`` is empty."""
        # Check that it's not empty
        if not images:
            raise ValueError('A subject without images cannot be created')

    def _load_history(self) -> OrderedDict:
        """Decode the JSON history, keeping the order of its entries."""
        return json.loads(self['history'], object_pairs_hook=OrderedDict)

    def _set_history(self, history: str) -> None:
        """Store the serialized history as item and as attribute.

        The attributes are only copied from the items in ``__init__``, so
        the ``history`` attribute must be refreshed explicitly to avoid
        ``subject.history`` and ``subject['history']`` drifting apart.
        """
        self['history'] = history
        self.__dict__['history'] = history

    @property
    def shape(self):
        """Return shape of first image in subject.

        Consistency of shapes across images in the subject is checked first.
        """
        self.check_consistent_shape()
        image = self.get_images(intensity_only=False)[0]
        return image.shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in subject.

        This is :attr:`shape` without its first element.
        Consistency of shapes across images in the subject is checked first.
        """
        return self.shape[1:]

    @property
    def spacing(self):
        """Return spacing of first image in subject.

        Consistency of shapes across images in the subject is checked first.
        """
        self.check_consistent_shape()
        image = self.get_images(intensity_only=False)[0]
        return image.spacing

    def get_images_dict(self, intensity_only=True):
        """Return a dictionary with the items of the subject that are images.

        Args:
            intensity_only: If ``True``, images whose ``TYPE`` entry is not
                ``INTENSITY`` are left out.
        """
        images = {}
        for image_name, image in self.items():
            if not isinstance(image, Image):
                continue
            if intensity_only and image[TYPE] != INTENSITY:
                continue
            images[image_name] = image
        return images

    def get_images(self, intensity_only=True):
        """Return the values of :meth:`get_images_dict` as a list."""
        images_dict = self.get_images_dict(intensity_only=intensity_only)
        return list(images_dict.values())

    def check_consistent_shape(self) -> None:
        """Check that all the images in the subject have the same shape.

        Raises:
            ValueError: If more than one distinct shape is found. The message
                lists the shape of each image.
        """
        images_dict = self.get_images_dict(intensity_only=False)
        shapes_dict = {
            image_name: image.shape
            for image_name, image in images_dict.items()
        }
        num_unique_shapes = len(set(shapes_dict.values()))
        if num_unique_shapes > 1:
            message = (
                'Images in sample have inconsistent shapes:'
                f'\n{pprint.pformat(shapes_dict)}'
            )
            raise ValueError(message)

    def add_transform(
            self,
            transform: 'Transform',
            seed: int,
            ) -> None:
        """Record a transform and its seed in the history of the subject.

        The entry is keyed by the class name of the transform, so recording
        another instance of the same class replaces the stored seed while
        keeping the position of the original entry.

        Args:
            transform: Object whose class name identifies the entry.
            seed: Value stored for the entry. It must be JSON serializable.
        """
        dictionary = self._load_history()
        dictionary[transform.__class__.__name__] = seed
        self._set_history(json.dumps(dictionary))

    def get_applied_transforms(self):
        """Return the history as a list of ``(name, seed)`` tuples.

        The tuples follow the order in which the entries were first added.
        """
        return list(self._load_history().items())

    def load(self):
        """Call ``load()`` on every image in the subject."""
        for image in self.get_images(intensity_only=False):
            image.load()

    def crop(self, index_ini, index_fin):
        """Return a new subject in which all the images have been cropped.

        Items that are not images are deep-copied. The history of the
        subject is carried over to the new instance.

        Args:
            index_ini: Forwarded to the ``crop`` method of each image.
            index_fin: Forwarded to the ``crop`` method of each image.
        """
        result_dict = {}
        for key, value in self.items():
            if isinstance(value, Image):
                # Images are cropped with their own crop method
                value = value.crop(index_ini, index_fin)
            else:
                value = copy.deepcopy(value)
            result_dict[key] = value
        new = Subject(result_dict)
        # __init__ resets the history, so restore the current one
        if 'history' in self:
            new._set_history(self['history'])
        return new
154