Completed
Pull Request — master (#353)
by Fernando
118:39 queued 117:31
created

DataParser.get_subject()   B

Complexity

Conditions 8

Size

Total Lines 31
Code Lines 28

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 8
eloc 28
nop 1
dl 0
loc 31
rs 7.3333
c 0
b 0
f 0
1
from typing import Optional, List, Sequence, Union
2
3
import torch
4
import numpy as np
5
import nibabel as nib
6
import SimpleITK as sitk
7
8
from .. import TypeData, DATA, AFFINE
9
from ..data.subject import Subject
10
from ..data.image import Image, ScalarImage
11
from ..utils import nib_to_sitk, sitk_to_nib
12
13
14
# Every kind of object that can be handed to a transform as input.
TypeTransformInput = Union[
    Subject, Image,             # TorchIO containers
    torch.Tensor, np.ndarray,   # raw arrays
    sitk.Image,                 # SimpleITK image
    dict,                       # plain dictionaries (used together with "keys")
    nib.Nifti1Image,            # NiBabel NIfTI image
]
23
24
25
class DataParser:
    """Convert a supported transform input into a ``Subject`` and back.

    ``get_subject()`` wraps ``data`` in a ``Subject`` and records which kind
    of input was received. ``get_output()`` uses that record to convert a
    transformed ``Subject`` back into the same kind of object.

    Args:
        data: Object to be parsed. See ``TypeTransformInput`` for the
            accepted types.
        keys: Names of the dictionary entries that must be treated as images.
            Required when ``data`` is a plain ``dict``, ignored otherwise.
    """

    def __init__(
            self,
            data: TypeTransformInput,
            keys: Optional[Sequence[str]] = None,
            ):
        self.data = data
        self.keys = keys
        # Name under which a lone image is stored in the temporary subject
        self.default_image_name = 'default_image_name'
        # Input-kind flags: written by get_subject(), read by get_output()
        self.is_tensor = False
        self.is_array = False
        self.is_dict = False
        self.is_image = False
        self.is_sitk = False
        self.is_nib = False

    def get_subject(self) -> Subject:
        """Return ``self.data`` wrapped in a ``Subject``.

        Also sets the ``is_*`` flag matching the input type so that
        ``get_output()`` can later restore that type.

        Raises:
            RuntimeError: If the input is a ``dict`` and no ``keys`` were
                given.
            ValueError: If the input type is not supported, or if a tensor or
                array input is not 4D.
        """
        # The order of the checks is kept from the original implementation
        if isinstance(self.data, nib.Nifti1Image):
            tensor = self.data.get_fdata(dtype=np.float32)
            if tensor.ndim == 3:
                # get_output() treats the first axis as channels (it checks
                # len(data) and returns data[0]), so a plain 3D volume needs
                # a leading channel axis for the round trip to keep its shape
                tensor = tensor[np.newaxis]
            data = ScalarImage(tensor=tensor, affine=self.data.affine)
            subject = self._get_subject_from_image(data)
            self.is_nib = True
        elif isinstance(self.data, (np.ndarray, torch.Tensor)):
            subject = self._parse_tensor(self.data)
            # is_tensor is set for arrays as well: both share an output
            # path, and is_array only adds the final conversion to NumPy
            self.is_array = isinstance(self.data, np.ndarray)
            self.is_tensor = True
        elif isinstance(self.data, Image):
            subject = self._get_subject_from_image(self.data)
            self.is_image = True
        elif isinstance(self.data, Subject):
            subject = self.data
        elif isinstance(self.data, sitk.Image):
            subject = self._get_subject_from_sitk_image(self.data)
            self.is_sitk = True
        elif isinstance(self.data, dict):  # e.g. Eisen or MONAI dicts
            if self.keys is None:
                message = (
                    'If input is a dictionary, a value for "keys" must be'
                    ' specified when instantiating the transform'
                )
                raise RuntimeError(message)
            subject = self._get_subject_from_dict(self.data, self.keys)
            self.is_dict = True
        else:
            raise ValueError(f'Input type not recognized: {type(self.data)}')
        self._parse_subject(subject)
        return subject

    def get_output(self, transformed: Subject) -> TypeTransformInput:
        """Convert ``transformed`` back to the type of the original input.

        Must be called after ``get_subject()``, which sets the flags read
        here. If no flag is set (the input was already a ``Subject``),
        ``transformed`` is returned unchanged.

        Raises:
            RuntimeError: If the input was a NIfTI image and the transformed
                image has more than one channel.
        """
        if self.is_tensor or self.is_sitk:
            image = transformed[self.default_image_name]
            transformed = image[DATA]
            if self.is_array:
                transformed = transformed.numpy()
            elif self.is_sitk:
                transformed = nib_to_sitk(image[DATA], image[AFFINE])
        elif self.is_image:
            transformed = transformed[self.default_image_name]
        elif self.is_dict:
            # Build a plain dict in which every image is replaced by its data
            transformed = {
                key: value.data if isinstance(value, Image) else value
                for key, value in dict(transformed).items()
            }
        elif self.is_nib:
            image = transformed[self.default_image_name]
            data = image[DATA]
            if len(data) > 1:
                message = (
                    'Multichannel images not supported for input of type'
                    ' nibabel.nifti.Nifti1Image'
                )
                raise RuntimeError(message)
            transformed = nib.Nifti1Image(data[0].numpy(), image[AFFINE])
        return transformed

    @staticmethod
    def _parse_subject(subject: Subject) -> None:
        """Raise ``RuntimeError`` if ``subject`` is not a ``Subject``."""
        if not isinstance(subject, Subject):
            message = (
                'Input to a transform must be a tensor or an instance'
                f' of torchio.Subject, not "{type(subject)}"'
            )
            raise RuntimeError(message)

    def _parse_tensor(self, data: TypeData) -> Subject:
        """Check that ``data`` is 4D and wrap it in a ``Subject``.

        Raises:
            ValueError: If ``data`` does not have exactly four dimensions.
        """
        if data.ndim != 4:
            message = (
                'The input must be a 4D tensor with dimensions'
                f' (channels, x, y, z) but it has shape {tuple(data.shape)}'
            )
            raise ValueError(message)
        return self._get_subject_from_tensor(data)

    def _get_subject_from_tensor(self, tensor: TypeData) -> Subject:
        """Wrap a tensor or array in a ``ScalarImage`` inside a ``Subject``."""
        image = ScalarImage(tensor=tensor)
        return self._get_subject_from_image(image)

    def _get_subject_from_image(self, image: Image) -> Subject:
        """Return a ``Subject`` holding ``image`` under the default name."""
        subject = Subject({self.default_image_name: image})
        return subject

    @staticmethod
    def _get_subject_from_dict(
            data: dict,
            image_keys: Sequence[str],
            ) -> Subject:
        """Build a ``Subject`` from ``data``.

        Entries whose key is listed in ``image_keys`` are wrapped in a
        ``ScalarImage``; all other entries are copied as they are.
        """
        if isinstance(image_keys, str):
            # A bare string would turn "key in image_keys" into a substring
            # test (e.g. 'age' in 'image'), so treat it as a single key
            image_keys = (image_keys,)
        subject_dict = {}
        for key, value in data.items():
            if key in image_keys:
                value = ScalarImage(tensor=value)
            subject_dict[key] = value
        return Subject(subject_dict)

    def _get_subject_from_sitk_image(self, image: sitk.Image) -> Subject:
        """Wrap a SimpleITK image in a ``ScalarImage`` inside a ``Subject``."""
        tensor, affine = sitk_to_nib(image)
        image = ScalarImage(tensor=tensor, affine=affine)
        return self._get_subject_from_image(image)
142