|
1
|
|
|
import pprint |
|
2
|
|
|
from typing import ( |
|
3
|
|
|
Any, |
|
4
|
|
|
Dict, |
|
5
|
|
|
List, |
|
6
|
|
|
Tuple, |
|
7
|
|
|
) |
|
8
|
|
|
from ..torchio import DATA |
|
9
|
|
|
from .image import Image |
|
10
|
|
|
|
|
11
|
|
|
|
|
12
|
|
|
class Subject(dict):
    """Class to store information about the images corresponding to a subject.

    A subject is a :class:`dict` whose values can be :class:`Image` instances
    or any other metadata (age, name, ...). At least one value must be an
    :class:`Image`.

    Args:
        *args: If provided, a dictionary of items. Its entries take precedence
            over keyword arguments that use the same key.
        **kwargs: Items that will be added to the subject sample.

    Raises:
        ValueError: If more than one positional argument is given, if the
            positional argument is not a :class:`dict`, or if none of the
            items is an :class:`Image`.

    Example:

        >>> import torchio
        >>> from torchio import Image, Subject
        >>> # One way:
        >>> subject = Subject(
        ...     one_image=Image('path_to_image.nii.gz', torchio.INTENSITY),
        ...     a_segmentation=Image('path_to_seg.nii.gz', torchio.LABEL),
        ...     age=45,
        ...     name='John Doe',
        ...     hospital='Hospital Juan Negrín',
        ... )
        >>> # If you want to create the mapping before, or have spaces in the keys:
        >>> subject_dict = {
        ...     'one image': Image('path_to_image.nii.gz', torchio.INTENSITY),
        ...     'a segmentation': Image('path_to_seg.nii.gz', torchio.LABEL),
        ...     'age': 45,
        ...     'name': 'John Doe',
        ...     'hospital': 'Hospital Juan Negrín',
        ... }
        >>> Subject(subject_dict)

    """

    def __init__(self, *args, **kwargs: Any):
        if args:
            if len(args) == 1 and isinstance(args[0], dict):
                # Entries of the positional dictionary overwrite keyword
                # arguments that share the same key.
                kwargs.update(args[0])
            else:
                message = (
                    'Only one dictionary as positional argument is allowed')
                raise ValueError(message)
        super().__init__(**kwargs)
        # (key, Image) pairs captured at construction time; this list is not
        # refreshed if items are added to or removed from the dict later.
        self.images = [
            (k, v) for (k, v) in self.items()
            if isinstance(v, Image)
        ]
        self._parse_images(self.images)
        self.is_sample = False  # set to True by ImagesDataset
        # (transform.name, parameters_dict) tuples appended by add_transform.
        self.history = []

    def __repr__(self):
        """Return the class name, the tuple of keys and the number of images."""
        string = (
            f'{self.__class__.__name__}'
            f'(Keys: {tuple(self.keys())}; images: {len(self.images)})'
        )
        return string

    @staticmethod
    def _parse_images(images: List[Tuple[str, Image]]) -> None:
        """Validate the (key, Image) pairs found in the subject.

        Raises:
            ValueError: If ``images`` is empty.
        """
        # Check that it's not empty
        if not images:
            raise ValueError('A subject without images cannot be created')

    def check_consistent_shape(self) -> None:
        """Raise an error if the compared images do not all share one shape.

        Only values that are :class:`Image` instances with a truthy
        ``is_sample`` attribute are compared; their shape is read from
        ``image[DATA].shape``.

        Raises:
            ValueError: If more than one distinct shape is found. The message
                lists the shape of every compared image, keyed by item name.
        """
        shapes_dict = {
            key: image[DATA].shape
            for key, image in self.items()
            if isinstance(image, Image) and image.is_sample
        }
        num_unique_shapes = len(set(shapes_dict.values()))
        if num_unique_shapes > 1:
            message = (
                'Images in sample have inconsistent shapes:'
                f'\n{pprint.pformat(shapes_dict)}'
            )
            raise ValueError(message)

    def add_transform(self, transform, parameters_dict):
        """Record a transform in ``self.history``.

        Appends the tuple ``(transform.name, parameters_dict)``; ``transform``
        must expose a ``name`` attribute.
        """
        self.history.append((transform.name, parameters_dict))
|
89
|
|
|
|