DataParser.get_output()   C
last analyzed

Complexity

Conditions 11

Size

Total Lines 26
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 11
eloc 24
nop 2
dl 0
loc 26
rs 5.4
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex code units like torchio.transforms.data_parser.DataParser.get_output() often do many different things. To break such a unit down, identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
from typing import Optional, List, Sequence, Union
2
3
import torch
4
import numpy as np
5
import nibabel as nib
6
import SimpleITK as sitk
7
8
from ..typing import TypeData
9
from ..data.subject import Subject
10
from ..data.image import Image, ScalarImage
11
from ..data.io import nib_to_sitk, sitk_to_nib
12
13
14
# Every kind of object a transform accepts as input; DataParser.get_subject
# dispatches on these types with isinstance checks.
TypeTransformInput = Union[
    # torchio containers
    Subject, Image,
    # raw tensors / arrays
    torch.Tensor, np.ndarray,
    # SimpleITK images, plain dictionaries (e.g. Eisen or MONAI) and NIfTI
    sitk.Image, dict, nib.Nifti1Image,
]
23
24
25
class DataParser:
    """Convert transform inputs to a :class:`Subject` and back.

    :meth:`get_subject` wraps the input in a :class:`Subject` and records
    what kind of object it was in the ``is_*`` flags; :meth:`get_output`
    reads those flags to turn the transformed subject back into the same
    kind of object the caller passed in.

    Args:
        data: Input to a transform: a :class:`Subject`, an :class:`Image`,
            a 4D tensor or NumPy array, a SimpleITK image, a NiBabel
            ``Nifti1Image`` or a dictionary.
        keys: Keys of ``data`` whose values are images. Only used, and
            then required, when ``data`` is a dictionary.
    """

    def __init__(
            self,
            data: TypeTransformInput,
            keys: Optional[Sequence[str]] = None,
            ):
        self.data = data
        self.keys = keys
        # Key under which a lone image is stored when the input is neither
        # a Subject nor a dictionary.
        self.default_image_name = 'default_image_name'
        # Input-type flags, set by get_subject() and read by get_output().
        self.is_tensor = False  # True for both torch tensors and NumPy arrays
        self.is_array = False  # additionally True for NumPy arrays
        self.is_dict = False
        self.is_image = False
        self.is_sitk = False
        self.is_nib = False

    def get_subject(self):
        """Return ``self.data`` wrapped in a :class:`Subject`.

        Also sets the ``is_*`` flag matching the input type so that
        :meth:`get_output` can restore it. A :class:`Subject` input is
        returned unchanged.

        Raises:
            RuntimeError: If ``data`` is a dictionary and ``keys`` is
                ``None``.
            ValueError: If the input type is not recognized, or if a tensor
                or array input is not 4D.
        """
        if isinstance(self.data, nib.Nifti1Image):
            subject = self._get_subject_from_nib_image(self.data)
            self.is_nib = True
        elif isinstance(self.data, (np.ndarray, torch.Tensor)):
            subject = self._parse_tensor(self.data)
            self.is_array = isinstance(self.data, np.ndarray)
            self.is_tensor = True
        elif isinstance(self.data, Image):
            subject = self._get_subject_from_image(self.data)
            self.is_image = True
        elif isinstance(self.data, Subject):
            subject = self.data
        elif isinstance(self.data, sitk.Image):
            subject = self._get_subject_from_sitk_image(self.data)
            self.is_sitk = True
        elif isinstance(self.data, dict):  # e.g. Eisen or MONAI dicts
            if self.keys is None:
                message = (
                    'If the input is a dictionary, a value for "include" must'
                    ' be specified when instantiating the transform. See the'
                    ' docs for Transform:'
                    ' https://torchio.readthedocs.io/transforms/transforms.html#torchio.transforms.Transform'  # noqa: E501
                )
                raise RuntimeError(message)
            subject = self._get_subject_from_dict(self.data, self.keys)
            self.is_dict = True
        else:
            raise ValueError(f'Input type not recognized: {type(self.data)}')
        self._parse_subject(subject)
        return subject

    def get_output(self, transformed):
        """Convert a transformed :class:`Subject` back to the input type.

        Args:
            transformed: Subject produced by the transform from the subject
                returned by :meth:`get_subject`.

        Returns:
            For tensor/array input, the image data (a NumPy array if the
            input was one); for SimpleITK input, a SimpleITK image; for
            :class:`Image` input, the transformed image; for dictionary
            input, a plain ``dict`` whose :class:`Image` values are replaced
            by their data; for NIfTI input, a ``Nifti1Image``; otherwise
            ``transformed`` itself.

        Raises:
            RuntimeError: If the input was a NIfTI image and the transformed
                image has more than one channel.
        """
        if self.is_tensor:
            return self._get_tensor_output(transformed)
        if self.is_sitk:
            return self._get_sitk_output(transformed)
        if self.is_image:
            return transformed[self.default_image_name]
        if self.is_dict:
            return self._get_dict_output(transformed)
        if self.is_nib:
            return self._get_nib_output(transformed)
        # The input was already a Subject.
        return transformed

    def _get_tensor_output(self, transformed):
        # Image data of the lone image, converted to NumPy if needed.
        tensor = transformed[self.default_image_name].data
        return tensor.numpy() if self.is_array else tensor

    def _get_sitk_output(self, transformed):
        # Rebuild a SimpleITK image from the lone image's data and affine.
        image = transformed[self.default_image_name]
        return nib_to_sitk(image.data, image.affine)

    @staticmethod
    def _get_dict_output(transformed):
        # Plain dict copy in which each Image is replaced by its data.
        return {
            key: value.data if isinstance(value, Image) else value
            for key, value in dict(transformed).items()
        }

    def _get_nib_output(self, transformed):
        # Rebuild a NIfTI image, dropping the channel axis that
        # get_subject() works with; only single-channel data can be stored.
        image = transformed[self.default_image_name]
        data = image.data
        if len(data) > 1:
            message = (
                'Multichannel images not supported for input of type'
                ' nibabel.nifti.Nifti1Image'
            )
            raise RuntimeError(message)
        return nib.Nifti1Image(data[0].numpy(), image.affine)

    @staticmethod
    def _parse_subject(subject: Subject) -> None:
        # Sanity check: every get_subject() branch must yield a Subject.
        if not isinstance(subject, Subject):
            message = (
                'Input to a transform must be a tensor or an instance'
                f' of torchio.Subject, not "{type(subject)}"'
            )
            raise RuntimeError(message)

    def _parse_tensor(self, data: TypeData) -> Subject:
        # Tensor/array input must already be channels-first 4D.
        if data.ndim != 4:
            message = (
                'The input must be a 4D tensor with dimensions'
                f' (channels, x, y, z) but it has shape {tuple(data.shape)}.'
                ' Tips: if it is a volume, please add the channels dimension;'
                ' if it is 2D, also add a dimension of size 1 for the z axis'
            )
            raise ValueError(message)
        return self._get_subject_from_tensor(data)

    def _get_subject_from_tensor(self, tensor: torch.Tensor) -> Subject:
        image = ScalarImage(tensor=tensor)
        return self._get_subject_from_image(image)

    def _get_subject_from_image(self, image: Image) -> Subject:
        # Store a lone image under the default key.
        subject = Subject({self.default_image_name: image})
        return subject

    def _get_subject_from_nib_image(self, image: nib.Nifti1Image) -> Subject:
        # Data in this class is channels-first 4D (see _parse_tensor), and
        # _get_nib_output() strips the channel axis again with data[0].
        # A plain 3D NIfTI volume therefore needs a leading channel axis.
        tensor = image.get_fdata(dtype=np.float32)
        if tensor.ndim == 3:
            tensor = tensor[np.newaxis]
        scalar_image = ScalarImage(tensor=tensor, affine=image.affine)
        return self._get_subject_from_image(scalar_image)

    @staticmethod
    def _get_subject_from_dict(
            data: dict,
            image_keys: List[str],
            ) -> Subject:
        # Wrap the values stored under image_keys as ScalarImages; copy the
        # remaining entries unchanged.
        subject_dict = {}
        for key, value in data.items():
            if key in image_keys:
                value = ScalarImage(tensor=value)
            subject_dict[key] = value
        return Subject(subject_dict)

    def _get_subject_from_sitk_image(self, image):
        tensor, affine = sitk_to_nib(image)
        image = ScalarImage(tensor=tensor, affine=affine)
        return self._get_subject_from_image(image)
146