Passed
Push — master ( b06930...3ddbe5 )
by Fernando
03:55
created

Transform.parse_subject()   A

Complexity

Conditions 2

Size

Total Lines 8
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 6
nop 1
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
import copy
2
import numbers
3
import warnings
4
from abc import ABC, abstractmethod
5
from typing import Optional, Union, Tuple, List
6
7
import torch
8
import numpy as np
9
import nibabel as nib
10
import SimpleITK as sitk
11
12
from .. import TypeData, DATA, AFFINE, TypeNumber
13
from ..data.subject import Subject
14
from ..data.image import Image, ScalarImage
15
from ..utils import nib_to_sitk, sitk_to_nib
16
from .interpolation import Interpolation
17
18
19
# Every input type accepted (and returned in the same type) by
# Transform.__call__.
TypeTransformInput = Union[
    Subject,
    Image,
    torch.Tensor,
    np.ndarray,
    dict,
    sitk.Image,
    nib.Nifti1Image,
]
20
21
22
class Transform(ABC):
    """Abstract class for all TorchIO transforms.

    All subclasses should overwrite
    :py:meth:`torchio.transforms.Transform.apply_transform`,
    which takes data, applies some transformation and returns the result.

    The input can be an instance of
    :py:class:`torchio.Subject`,
    :py:class:`torchio.Image`,
    :py:class:`numpy.ndarray`,
    :py:class:`torch.Tensor`,
    :py:class:`SimpleITK.Image`,
    :py:class:`nibabel.Nifti1Image`,
    or a Python dictionary.

    Args:
        p: Probability that this transform will be applied.
        copy: Make a shallow copy of the input before applying the transform.
        keys: Mandatory if the input is a Python dictionary. The transform will
            be applied only to the data in each key.
    """
    def __init__(
            self,
            p: float = 1,
            copy: bool = True,
            keys: Optional[List[str]] = None,
            ):
        self.probability = self.parse_probability(p)
        self.copy = copy
        self.keys = keys
        # Key under which a single image (from a tensor, array, Image,
        # SimpleITK or NiBabel input) is stored in the temporary Subject.
        self.default_image_name = 'default_image_name'

    def __call__(self, data: TypeTransformInput) -> TypeTransformInput:
        """Transform data and return a result of the same type.

        Args:
            data: Instance of :py:class:`~torchio.Subject`,
                :py:class:`~torchio.Image`, :py:class:`SimpleITK.Image`,
                :py:class:`nibabel.Nifti1Image`, a Python dictionary (only
                if ``keys`` was given at instantiation), a 4D
                :py:class:`torch.Tensor` or a 4D NumPy array with dimensions
                :math:`(C, W, H, D)`. Here :math:`C` is the number of channels
                and :math:`W, H, D` are the spatial dimensions. If the input is
                a tensor or array, the affine matrix is an identity, and a
                tensor or array will also be returned.

        Raises:
            RuntimeError: If the input is a dictionary but ``keys`` is
                ``None``, or if a multichannel image must be returned as a
                NiBabel image.
            ValueError: If the input type is not recognized, or if a tensor
                or array input is not 4D.
        """
        # torch.rand samples from [0, 1), so ">=" makes p=1 always apply the
        # transform and p=0 never apply it.
        if torch.rand(1).item() >= self.probability:
            return data

        is_tensor = is_array = is_dict = is_image = is_sitk = is_nib = False

        if isinstance(data, nib.Nifti1Image):
            tensor = data.get_fdata(dtype=np.float32)
            # Images are handled as 4D (C, W, H, D). A single 3D volume gets
            # a leading channel axis, which is removed again on output.
            if tensor.ndim == 3:
                tensor = tensor[np.newaxis]
            data = ScalarImage(tensor=tensor, affine=data.affine)
            subject = self._get_subject_from_image(data)
            is_nib = True
        elif isinstance(data, (np.ndarray, torch.Tensor)):
            subject = self.parse_tensor(data)
            is_array = isinstance(data, np.ndarray)
            is_tensor = True
        elif isinstance(data, Image):
            subject = self._get_subject_from_image(data)
            is_image = True
        elif isinstance(data, Subject):
            subject = data
        elif isinstance(data, sitk.Image):
            subject = self._get_subject_from_sitk_image(data)
            is_sitk = True
        elif isinstance(data, dict):  # e.g. Eisen or MONAI dicts
            if self.keys is None:
                message = (
                    'If input is a dictionary, a value for "keys" must be'
                    ' specified when instantiating the transform'
                )
                raise RuntimeError(message)
            subject = self._get_subject_from_dict(data, self.keys)
            is_dict = True
        else:
            raise ValueError(f'Input type not recognized: {type(data)}')
        self.parse_subject(subject)

        if self.copy:
            subject = copy.copy(subject)

        # Turn silent NumPy floating-point problems (overflow, invalid values,
        # ...) into exceptions while the transform runs.
        with np.errstate(all='raise'):
            transformed = self.apply_transform(subject)

        # Internal invariant: every transform must keep images 4D.
        for image in transformed.get_images(intensity_only=False):
            ndim = image[DATA].ndim
            assert ndim == 4, f'Output of {self.name} is {ndim}D'

        # Convert the output back to the type of the input.
        if is_tensor:
            transformed = transformed[self.default_image_name][DATA]
            if is_array:
                transformed = transformed.numpy()
        elif is_sitk:
            image = transformed[self.default_image_name]
            transformed = nib_to_sitk(image[DATA], image[AFFINE])
        elif is_image:
            transformed = transformed[self.default_image_name]
        elif is_dict:
            # Unwrap images back to their raw data; other values pass through.
            transformed = {
                key: value.data if isinstance(value, Image) else value
                for key, value in transformed.items()
            }
        elif is_nib:
            image = transformed[self.default_image_name]
            data = image[DATA]
            if len(data) > 1:
                message = (
                    'Multichannel images not supported for input of type'
                    ' nibabel.nifti.Nifti1Image'
                )
                raise RuntimeError(message)
            transformed = nib.Nifti1Image(data[0].numpy(), image[AFFINE])
        return transformed

    @abstractmethod
    def apply_transform(self, subject: Subject):
        """Apply the transform to ``subject`` and return the result."""
        raise NotImplementedError

    @staticmethod
    def parse_range(
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
            name: str,
            min_constraint: TypeNumber = None,
            max_constraint: TypeNumber = None,
            type_constraint: type = None,
            ) -> Tuple[TypeNumber, TypeNumber]:
        r"""Adapted from ``torchvision.transforms.RandomRotation``.

        Args:
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
                where :math:`n_{min} \leq n_{max}`.
                If a single non-negative number :math:`n` is provided,
                :math:`n_{min} = -n` and :math:`n_{max} = n` when
                :attr:`min_constraint` is None. Otherwise
                :math:`n_{min} = n_{max} = n`.
            name: Name of the parameter, so that an informative error message
                can be printed.
            min_constraint: Minimal value that :math:`n_{min}` can take,
                default is None, i.e. there is no minimal value.
            max_constraint: Maximal value that :math:`n_{max}` can take,
                default is None, i.e. there is no maximal value.
            type_constraint: Precise type that :math:`n_{max}` and
                :math:`n_{min}` must take.

        Returns:
            A tuple of two numbers :math:`(n_{min}, n_{max})`. If a sequence
            was given, that same sequence object is returned.

        Raises:
            ValueError: if :attr:`nums_range` is a single negative number
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
            ValueError: if :math:`n_{max} \lt n_{min}`
            ValueError: if :attr:`min_constraint` is not None and
                :math:`n_{min}` is smaller than :attr:`min_constraint`
            ValueError: if :attr:`max_constraint` is not None and
                :math:`n_{max}` is greater than :attr:`max_constraint`
            ValueError: if :attr:`type_constraint` is not None and
                :math:`n_{min}` and :math:`n_{max}` are not of type
                :attr:`type_constraint`.
        """
        if isinstance(nums_range, numbers.Number):
            if nums_range < 0:
                raise ValueError(
                    f'If {name} is a single number,'
                    f' it must be non-negative, not {nums_range}')
            if min_constraint is not None and nums_range < min_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be greater'
                    f' than {min_constraint}, not {nums_range}'
                )
            if max_constraint is not None and nums_range > max_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be smaller'
                    f' than {max_constraint}, not {nums_range}'
                )
            if type_constraint is not None:
                if not isinstance(nums_range, type_constraint):
                    raise ValueError(
                        f'If {name} is a single number, it must be of'
                        f' type {type_constraint}, not {type(nums_range)}'
                    )
            # Without a lower bound a single number n means (-n, n);
            # with one, the range collapses to (n, n).
            min_range = -nums_range if min_constraint is None else nums_range
            return (min_range, nums_range)

        try:
            min_degree, max_degree = nums_range
        except (TypeError, ValueError) as error:
            raise ValueError(
                f'If {name} is not a single number, it must be'
                f' a sequence of len 2, not {nums_range}'
            ) from error

        min_is_number = isinstance(min_degree, numbers.Number)
        max_is_number = isinstance(max_degree, numbers.Number)
        if not min_is_number or not max_is_number:
            message = (
                f'{name} values must be numbers, not {nums_range}')
            raise ValueError(message)

        if min_degree > max_degree:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, but it is {nums_range}')

        if min_constraint is not None and min_degree < min_constraint:
            raise ValueError(
                f'If {name} is a sequence, the first value must be greater'
                f' than {min_constraint}, but it is {min_degree}'
            )

        if max_constraint is not None and max_degree > max_constraint:
            raise ValueError(
                f'If {name} is a sequence, the second value must be smaller'
                f' than {max_constraint}, but it is {max_degree}'
            )

        if type_constraint is not None:
            min_type_ok = isinstance(min_degree, type_constraint)
            max_type_ok = isinstance(max_degree, type_constraint)
            if not min_type_ok or not max_type_ok:
                # Report the element types, which are what was checked.
                raise ValueError(
                    f'If "{name}" is a sequence, its values must be of'
                    f' type "{type_constraint}", not'
                    f' "{type(min_degree)}" and "{type(max_degree)}"'
                )
        return nums_range

    @staticmethod
    def parse_probability(probability: float) -> float:
        """Return ``probability`` after checking it is a real number in [0, 1].

        Raises:
            ValueError: If ``probability`` is not a real number in [0, 1].
        """
        # numbers.Real (not numbers.Number) rejects complex values, which
        # would otherwise raise a TypeError in the comparison below.
        is_number = isinstance(probability, numbers.Real)
        if not (is_number and 0 <= probability <= 1):
            message = (
                'Probability must be a number in [0, 1],'
                f' not {probability}'
            )
            raise ValueError(message)
        return probability

    @staticmethod
    def parse_subject(subject: Subject) -> None:
        """Raise a RuntimeError if ``subject`` is not a ``Subject``."""
        if not isinstance(subject, Subject):
            message = (
                'Input to a transform must be a tensor or an instance'
                f' of torchio.Subject, not "{type(subject)}"'
            )
            raise RuntimeError(message)

    def parse_tensor(self, data: TypeData) -> Subject:
        """Wrap a 4D tensor or array in a single-image ``Subject``.

        Raises:
            ValueError: If ``data`` is not 4D.
        """
        if data.ndim != 4:
            message = (
                'The input must be a 4D tensor with dimensions'
                f' (channels, x, y, z) but it has shape {tuple(data.shape)}'
            )
            raise ValueError(message)
        return self._get_subject_from_tensor(data)

    @staticmethod
    def parse_interpolation(interpolation: str) -> Interpolation:
        """Convert an interpolation name to an ``Interpolation`` member.

        Matching is case-insensitive. Passing an ``Interpolation`` member
        still works but emits a ``FutureWarning``.

        Raises:
            AttributeError: If the string is not a supported interpolation.
            TypeError: If ``interpolation`` is neither a string nor an
                ``Interpolation``.
        """
        if isinstance(interpolation, Interpolation):
            message = (
                'Interpolation of type torchio.Interpolation'
                ' is deprecated, please use a string instead'
            )
            warnings.warn(message, FutureWarning)
        elif isinstance(interpolation, str):
            interpolation = interpolation.lower()
            supported_values = [key.name.lower() for key in Interpolation]
            if interpolation in supported_values:
                interpolation = getattr(Interpolation, interpolation.upper())
            else:
                message = (
                    f'Interpolation "{interpolation}" is not among'
                    f' the supported values: {supported_values}'
                )
                raise AttributeError(message)
        else:
            message = (
                'image_interpolation must be a string,'
                f' not {type(interpolation)}'
            )
            raise TypeError(message)
        return interpolation

    def _get_subject_from_tensor(self, tensor: torch.Tensor) -> Subject:
        """Wrap ``tensor`` in a ScalarImage inside a single-image Subject."""
        image = ScalarImage(tensor=tensor)
        return self._get_subject_from_image(image)

    def _get_subject_from_image(self, image: Image) -> Subject:
        """Store ``image`` in a Subject under ``self.default_image_name``."""
        subject = Subject({self.default_image_name: image})
        return subject

    @staticmethod
    def _get_subject_from_dict(
            data: dict,
            image_keys: List[str],
            ) -> Subject:
        """Build a Subject from ``data``.

        Values under ``image_keys`` are wrapped in ScalarImage. All other
        values are kept unchanged.
        """
        subject_dict = {}
        for key, value in data.items():
            if key in image_keys:
                value = ScalarImage(tensor=value)
            subject_dict[key] = value
        return Subject(subject_dict)

    def _get_subject_from_sitk_image(self, image: sitk.Image) -> Subject:
        """Convert a SimpleITK image to a single-image Subject."""
        tensor, affine = sitk_to_nib(image)
        image = ScalarImage(tensor=tensor, affine=affine)
        return self._get_subject_from_image(image)

    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        """Delegate to the module-level ``nib_to_sitk`` utility."""
        return nib_to_sitk(data, affine)

    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        """Delegate to the module-level ``sitk_to_nib`` utility."""
        return sitk_to_nib(image)

    @property
    def name(self) -> str:
        """Name of the concrete transform class."""
        return self.__class__.__name__
337