Completed
Push — master ( 876b6a...63061a )
by Fernando
02:53 queued 01:24
created

torchio.data.subject.Subject.load()   A

Complexity

Conditions 2

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 3
nop 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
import pprint
2
from typing import (
3
    Any,
4
    Dict,
5
    List,
6
    Tuple,
7
)
8
from ..torchio import TYPE, INTENSITY
9
from .image import Image
10
11
12
class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    Args:
        *args: If provided, a dictionary of items.
        **kwargs: Items that will be added to the subject sample.

    Example:

        >>> import torchio
        >>> from torchio import Image, Subject
        >>> # One way:
        >>> subject = Subject(
        ...     one_image=Image('path_to_image.nii.gz', type=torchio.INTENSITY),
        ...     a_segmentation=Image('path_to_seg.nii.gz', type=torchio.LABEL),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping before, or have spaces in the keys:
        >>> subject_dict = {
        ...     'one image': Image('path_to_image.nii.gz', type=torchio.INTENSITY),
        ...     'a segmentation': Image('path_to_seg.nii.gz', type=torchio.LABEL),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> Subject(subject_dict)

    """

    def __init__(self, *args, **kwargs: Any):
        """Build the subject from at most one dict plus keyword items.

        Raises:
            ValueError: If more than one positional argument is given, or the
                single positional argument is not a ``dict``, or if the
                resulting subject contains no :class:`Image` instances.
        """
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                # Entries of the positional dict take precedence over keyword
                # arguments with the same key.
                kwargs.update(args[0])
            else:
                message = (
                    'Only one dictionary as positional argument is allowed')
                raise ValueError(message)
        # Pass the merged mapping positionally rather than as **kwargs so that
        # keys which are not valid keyword names (e.g. non-str keys coming
        # from the positional dict) do not trigger a confusing TypeError.
        super().__init__(kwargs)
        # (name, image) pairs captured at construction time.
        self.images = [
            (key, value) for key, value in self.items()
            if isinstance(value, Image)
        ]
        self._parse_images(self.images)
        # Expose every item as an attribute too (e.g. ``subject.t1``). This is
        # a snapshot: later changes to the mapping are not mirrored here.
        self.__dict__.update(self)
        self.is_sample = False  # set to True by ImagesDataset
        self.history = []

    def __repr__(self):
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {len(self.images)})'
        )
        return string

    @staticmethod
    def _parse_images(images: List[Tuple[str, Image]]) -> None:
        """Validate the ``(name, image)`` pairs found in the subject.

        Raises:
            ValueError: If ``images`` is empty.
        """
        if not images:
            raise ValueError('A subject without images cannot be created')

    @property
    def shape(self):
        """Return shape of first image in sample.

        Consistency of shapes across images in the sample is checked first.

        Raises:
            ValueError: If the images in the sample have different shapes.
        """
        self.check_consistent_shape()
        image = self.get_images(intensity_only=False)[0]
        return image.shape

    @property
    def spatial_shape(self):
        """Return spatial shape of first image in sample.

        This is :attr:`shape` without its first dimension (presumably the
        channels dimension — confirm against ``Image.shape``). Consistency of
        shapes across images in the sample is checked first.
        """
        return self.shape[1:]

    def get_images_dict(self, intensity_only=True) -> Dict[str, Image]:
        """Return the :class:`Image` items of the subject, keyed by name.

        Args:
            intensity_only: If ``True``, keep only images whose ``TYPE``
                entry equals ``INTENSITY``; otherwise return all images.
        """
        return {
            name: value
            for name, value in self.items()
            if isinstance(value, Image)
            and (not intensity_only or value[TYPE] == INTENSITY)
        }

    def get_images(self, intensity_only=True) -> List[Image]:
        """Return the images of the subject as a list.

        Args:
            intensity_only: Same meaning as in :meth:`get_images_dict`.
        """
        images_dict = self.get_images_dict(intensity_only=intensity_only)
        return list(images_dict.values())

    def check_consistent_shape(self) -> None:
        """Check that all images (of any type) share the same shape.

        Raises:
            ValueError: If more than one distinct shape is found. The message
                lists the shape of every image by name.
        """
        all_images = self.get_images_dict(intensity_only=False)
        shapes_dict = {name: image.shape for name, image in all_images.items()}
        num_unique_shapes = len(set(shapes_dict.values()))
        if num_unique_shapes > 1:
            message = (
                'Images in sample have inconsistent shapes:'
                f'\n{pprint.pformat(shapes_dict)}'
            )
            raise ValueError(message)

    def add_transform(
            self,
            transform: 'Transform',
            parameters_dict: dict,
            ) -> None:
        """Append ``(transform.name, parameters_dict)`` to ``self.history``."""
        self.history.append((transform.name, parameters_dict))

    def load(self):
        """Call ``load()`` on every image of the subject, of any type."""
        for image in self.get_images(intensity_only=False):
            image.load()
129