Passed
Pull Request — master (#353)
by Fernando
01:07
created

Transform.parse_params()   C

Complexity

Conditions 10

Size

Total Lines 18
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 10
eloc 15
nop 6
dl 0
loc 18
rs 5.9999
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex classes like torchio.transforms.transform.Transform.parse_params() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import copy
2
import numbers
3
from abc import ABC, abstractmethod
4
from typing import Optional, Union, Tuple, List
5
6
import torch
7
import numpy as np
8
import nibabel as nib
9
import SimpleITK as sitk
10
11
from .. import TypeData, DATA, AFFINE, TypeNumber
12
from ..data.subject import Subject
13
from ..data.image import Image, ScalarImage
14
from ..utils import nib_to_sitk, sitk_to_nib, to_tuple
15
from .interpolation import Interpolation
16
17
18
# Every kind of input accepted by Transform.__call__; each of them is
# converted to a Subject (and back) by DataToSubject.
TypeTransformInput = Union[
    Subject,
    Image,
    torch.Tensor,
    np.ndarray,
    sitk.Image,
    nib.Nifti1Image,
    dict,
]
26
27
28
class Transform(ABC):
    """Abstract class for all TorchIO transforms.

    All subclasses should overwrite
    :py:meth:`torchio.transforms.Transform.apply_transform`,
    which takes data, applies some transformation and returns the result.

    The input can be an instance of
    :py:class:`torchio.Subject`,
    :py:class:`torchio.Image`,
    :py:class:`numpy.ndarray`,
    :py:class:`torch.Tensor`,
    :py:class:`SimpleITK.Image`,
    :py:class:`nibabel.Nifti1Image`,
    or a Python dictionary.

    Args:
        p: Probability that this transform will be applied.
        copy: Make a shallow copy of the input before applying the transform.
        keys: Mandatory if the input is a Python dictionary. The transform will
            be applied only to the data in each key.
    """
    def __init__(
            self,
            p: float = 1,
            copy: bool = True,
            keys: Optional[List[str]] = None,
            ):
        self.probability = self.parse_probability(p)
        self.copy = copy
        self.keys = keys

    def __call__(
            self,
            data: TypeTransformInput,
            ) -> TypeTransformInput:
        """Transform data and return a result of the same type.

        Args:
            data: Instance of :py:class:`~torchio.Subject`, 4D
                :py:class:`torch.Tensor` or 4D NumPy array with dimensions
                :math:`(C, W, H, D)`, where :math:`C` is the number of channels
                and :math:`W, H, D` are the spatial dimensions. If the input is
                a tensor, the affine matrix is an identity and a tensor will be
                also returned. :py:class:`~torchio.Image`,
                :py:class:`SimpleITK.Image`, :py:class:`nibabel.Nifti1Image`
                and dictionaries are also accepted.
        """
        # The transform is skipped (input returned untouched) with
        # probability 1 - p
        if torch.rand(1).item() > self.probability:
            return data

        data_parser = DataToSubject(data, keys=self.keys)
        subject = data_parser.get_subject()

        if self.copy:
            subject = copy.copy(subject)

        # Make NumPy floating-point problems raise instead of only warning
        with np.errstate(all='raise'):
            transformed = self.apply_transform(subject)

        self.add_transform_to_subject_history(transformed)

        # Sanity check: every transform must keep images 4D
        for image in transformed.get_images(intensity_only=False):
            ndim = image[DATA].ndim
            assert ndim == 4, f'Output of {self.name} is {ndim}D'

        output = data_parser.get_output(transformed)

        return output

    @abstractmethod
    def apply_transform(self, subject: Subject):
        """Apply the transform to ``subject`` and return the result."""
        raise NotImplementedError

    def add_transform_to_subject_history(self, subject):
        """Record this transform and its arguments in the subject history.

        Random transforms and compositions are not recorded here.
        """
        # Imported here to avoid circular imports at module load time
        from .augmentation import RandomTransform
        from .augmentation.composition import Compose, OneOf
        if not isinstance(self, (RandomTransform, Compose, OneOf)):
            subject.add_transform(self, self.get_arguments())

    @staticmethod
    def to_range(n, around):
        """Return ``(0, n)`` if ``around`` is None, else ``around ± n``."""
        if around is None:
            return 0, n
        else:
            return around - n, around + n

    def parse_params(self, params, around, name, make_ranges=True, **kwargs):
        """Parse a per-axis parameter specification into a flat tuple.

        Args:
            params: A single value ``d``, a pair ``(a, b)``, a triplet
                ``(a, b, c)`` or six values. It is first normalized with
                ``to_tuple``.
            around: Center used to turn a single value ``n`` into a range:
                ``(0, n)`` if ``None``, otherwise ``(around - n, around + n)``.
            name: Name of the parameter, used in error messages.
            make_ranges: If ``True``, expand ``params`` into six values
                (three ``(min, max)`` ranges) and validate each range with
                :py:meth:`parse_range`. If ``False``, a single value is
                replicated three times and nothing is validated.
            **kwargs: Extra keyword arguments forwarded to
                :py:meth:`parse_range` (e.g. ``min_constraint``).

        Returns:
            The parsed parameters as a tuple.

        Raises:
            ValueError: If ``make_ranges`` is ``True`` and the parameters
                cannot be expanded to six values, or a range is invalid.
        """
        params = to_tuple(params)
        if make_ranges:
            params = self._expand_to_ranges(params, around, name)
            # Validate each consecutive (min, max) pair. Previously this loop
            # was unreachable, so constraints were never checked.
            for param_range in zip(params[::2], params[1::2]):
                self.parse_range(param_range, name, **kwargs)
        elif len(params) == 1:
            params *= 3  # (d, d, d)
        return tuple(params)

    def _expand_to_ranges(self, params, around, name):
        """Expand 1, 2, 3 or 6 values into six ``(min, max)`` bounds."""
        if len(params) in (1, 2):  # d or (a, b)
            params *= 3  # (d, d, d) or (a, b, a, b, a, b)
        if len(params) == 3:  # (a, b, c)
            items = [self.to_range(n, around) for n in params]
            # (-a, a, -b, b, -c, c) or (1-a, 1+a, 1-b, 1+b, 1-c, 1+c)
            params = [n for prange in items for n in prange]
        if len(params) != 6:
            message = (
                f'If "{name}" is a sequence, it must have length 2, 3 or 6,'
                f' not {len(params)}'
            )
            raise ValueError(message)
        return params

    @staticmethod
    def parse_range(
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
            name: str,
            min_constraint: TypeNumber = None,
            max_constraint: TypeNumber = None,
            type_constraint: type = None,
            ) -> Tuple[TypeNumber, TypeNumber]:
        r"""Adapted from ``torchvision.transforms.RandomRotation``.

        Args:
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
                where :math:`n_{min} \leq n_{max}`.
                If a single positive number :math:`n` is provided,
                :math:`n_{min} = -n` and :math:`n_{max} = n`
                (or :math:`n_{min} = n` if :attr:`min_constraint` is given).
            name: Name of the parameter, so that an informative error message
                can be printed.
            min_constraint: Minimal value that :math:`n_{min}` can take,
                default is None, i.e. there is no minimal value.
            max_constraint: Maximal value that :math:`n_{max}` can take,
                default is None, i.e. there is no maximal value.
            type_constraint: Precise type that :math:`n_{max}` and
                :math:`n_{min}` must take.

        Returns:
            A tuple of two numbers :math:`(n_{min}, n_{max})`.

        Raises:
            ValueError: if :attr:`nums_range` is negative
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
            ValueError: if :math:`n_{max} \lt n_{min}`
            ValueError: if :attr:`min_constraint` is not None and
                :math:`n_{min}` is smaller than :attr:`min_constraint`
            ValueError: if :attr:`max_constraint` is not None and
                :math:`n_{max}` is greater than :attr:`max_constraint`
            ValueError: if :attr:`type_constraint` is not None and
                :math:`n_{min}` and :math:`n_{max}` are not of type
                :attr:`type_constraint`.
        """
        if isinstance(nums_range, numbers.Number):  # single number given
            if nums_range < 0:
                raise ValueError(
                    f'If {name} is a single number,'
                    f' it must be positive, not {nums_range}')
            if min_constraint is not None and nums_range < min_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be greater'
                    f' than {min_constraint}, not {nums_range}'
                )
            if max_constraint is not None and nums_range > max_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be smaller'
                    f' than {max_constraint}, not {nums_range}'
                )
            if type_constraint is not None:
                if not isinstance(nums_range, type_constraint):
                    raise ValueError(
                        f'If {name} is a single number, it must be of'
                        f' type {type_constraint}, not {nums_range}'
                    )
            min_range = -nums_range if min_constraint is None else nums_range
            return (min_range, nums_range)

        try:
            min_value, max_value = nums_range
        except (TypeError, ValueError):
            raise ValueError(
                f'If {name} is not a single number, it must be'
                f' a sequence of len 2, not {nums_range}'
            )

        min_is_number = isinstance(min_value, numbers.Number)
        max_is_number = isinstance(max_value, numbers.Number)
        if not min_is_number or not max_is_number:
            message = (
                f'{name} values must be numbers, not {nums_range}')
            raise ValueError(message)

        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, but it is {nums_range}')

        if min_constraint is not None and min_value < min_constraint:
            raise ValueError(
                f'If {name} is a sequence, the first value must be greater'
                f' than {min_constraint}, but it is {min_value}'
            )

        if max_constraint is not None and max_value > max_constraint:
            raise ValueError(
                f'If {name} is a sequence, the second value must be smaller'
                f' than {max_constraint}, but it is {max_value}'
            )

        if type_constraint is not None:
            min_type_ok = isinstance(min_value, type_constraint)
            max_type_ok = isinstance(max_value, type_constraint)
            if not min_type_ok or not max_type_ok:
                # Report the element types; the container type is irrelevant
                raise ValueError(
                    f'If "{name}" is a sequence, its values must be of'
                    f' type "{type_constraint}", not'
                    f' "{type(min_value)}" and "{type(max_value)}"'
                )
        return nums_range

    @staticmethod
    def parse_interpolation(interpolation: str) -> str:
        """Return the lowercase interpolation name if it is supported.

        Raises:
            TypeError: If ``interpolation`` is not a string or is not one of
                the (lowercase) names of :py:class:`Interpolation`.
        """
        supported_values = [key.name.lower() for key in Interpolation]
        # Check the type first: calling .lower() on a non-string used to
        # raise AttributeError instead of the documented TypeError
        if isinstance(interpolation, str):
            interpolation = interpolation.lower()
            if interpolation in supported_values:
                return interpolation
        message = (
            f'Interpolation "{interpolation}" of type {type(interpolation)}'
            f' must be a string among the supported values: {supported_values}'
        )
        raise TypeError(message)

    @staticmethod
    def parse_probability(probability: float) -> float:
        """Return ``probability`` if it is a number in [0, 1].

        Raises:
            ValueError: Otherwise.
        """
        is_number = isinstance(probability, numbers.Number)
        if not (is_number and 0 <= probability <= 1):
            message = (
                'Probability must be a number in [0, 1],'
                f' not {probability}'
            )
            raise ValueError(message)
        return probability

    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        """Thin wrapper around the ``nib_to_sitk`` utility."""
        return nib_to_sitk(data, affine)

    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        """Thin wrapper around the ``sitk_to_nib`` utility."""
        return sitk_to_nib(image)

    @property
    def name(self):
        """Name of the concrete transform class."""
        return self.__class__.__name__

    def is_invertible(self):
        """Return True if the transform has an ``invert_transform`` flag."""
        return hasattr(self, 'invert_transform')

    def inverse(self):
        """Return a deep copy of the transform with ``invert_transform`` toggled.

        Raises:
            RuntimeError: If the transform is not invertible.
        """
        if not self.is_invertible():
            raise RuntimeError(f'{self.name} is not invertible')
        new = copy.deepcopy(self)
        new.invert_transform = not self.invert_transform
        return new
282
283
284
class DataToSubject:
    """Convert a transform input to a :py:class:`Subject` and back.

    :py:meth:`get_subject` wraps the input in a ``Subject`` and records which
    kind of input was received. :py:meth:`get_output` uses that record to
    convert the transformed subject back to the same kind of object.

    Args:
        data: Input given to the transform.
        keys: Keys of the dictionary entries that hold image data. Required
            if ``data`` is a dictionary.
    """
    def __init__(
            self,
            data: TypeTransformInput,
            keys: Optional[List[str]] = None,
            ):
        self.data = data
        self.keys = keys
        # Key under which a single image is stored in the wrapping subject
        self.default_image_name = 'default_image_name'
        # Flags describing the input kind; set by get_subject()
        self.is_tensor = False  # True for both tensors and NumPy arrays
        self.is_array = False
        self.is_dict = False
        self.is_image = False
        self.is_sitk = False
        self.is_nib = False

    def get_subject(self):
        """Wrap the input in a :py:class:`Subject` and record its kind.

        Raises:
            RuntimeError: If the input is a dictionary and ``keys`` is None.
            ValueError: If the input type is not supported, or a tensor or
                array input is not 4D.
        """
        if isinstance(self.data, nib.Nifti1Image):
            tensor = self.data.get_fdata(dtype=np.float32)
            if tensor.ndim == 3:
                # Add the channel axis: transforms work on 4D (C, W, H, D)
                # data and get_output() removes this axis again with data[0]
                tensor = tensor[np.newaxis]
            data = ScalarImage(tensor=tensor, affine=self.data.affine)
            subject = self._get_subject_from_image(data)
            self.is_nib = True
        elif isinstance(self.data, (np.ndarray, torch.Tensor)):
            subject = self._parse_tensor(self.data)
            self.is_array = isinstance(self.data, np.ndarray)
            self.is_tensor = True
        elif isinstance(self.data, Image):
            subject = self._get_subject_from_image(self.data)
            self.is_image = True
        elif isinstance(self.data, Subject):
            subject = self.data
        elif isinstance(self.data, sitk.Image):
            subject = self._get_subject_from_sitk_image(self.data)
            self.is_sitk = True
        elif isinstance(self.data, dict):  # e.g. Eisen or MONAI dicts
            if self.keys is None:
                message = (
                    'If input is a dictionary, a value for "keys" must be'
                    ' specified when instantiating the transform'
                )
                raise RuntimeError(message)
            subject = self._get_subject_from_dict(self.data, self.keys)
            self.is_dict = True
        else:
            raise ValueError(f'Input type not recognized: {type(self.data)}')
        self._parse_subject(subject)
        return subject

    def get_output(self, transformed):
        """Convert a transformed subject back to the original input kind.

        Raises:
            RuntimeError: If the input was a NIfTI image and the transformed
                image has more than one channel.
        """
        if self.is_tensor or self.is_sitk:
            image = transformed[self.default_image_name]
            transformed = image[DATA]
            if self.is_array:
                transformed = transformed.numpy()
            elif self.is_sitk:
                transformed = nib_to_sitk(image[DATA], image[AFFINE])
        elif self.is_image:
            transformed = transformed[self.default_image_name]
        elif self.is_dict:
            # Plain dict whose Image values are replaced by their data
            transformed = dict(transformed)
            for key, value in transformed.items():
                if isinstance(value, Image):
                    transformed[key] = value.data
        elif self.is_nib:
            image = transformed[self.default_image_name]
            data = image[DATA]
            if len(data) > 1:
                message = (
                    'Multichannel images not supported for input of type'
                    ' nibabel.nifti.Nifti1Image'
                )
                raise RuntimeError(message)
            # Drop the channel axis added in get_subject()
            transformed = nib.Nifti1Image(data[0].numpy(), image[AFFINE])
        # Subject inputs are returned as the transformed subject itself
        return transformed

    @staticmethod
    def _parse_subject(subject: Subject) -> None:
        """Raise RuntimeError if ``subject`` is not a :py:class:`Subject`."""
        if not isinstance(subject, Subject):
            message = (
                'Input to a transform must be a tensor or an instance'
                f' of torchio.Subject, not "{type(subject)}"'
            )
            raise RuntimeError(message)

    def _parse_tensor(self, data: TypeData) -> Subject:
        """Wrap a 4D tensor or array in a subject; raise ValueError if not 4D."""
        if data.ndim != 4:
            message = (
                'The input must be a 4D tensor with dimensions'
                f' (channels, x, y, z) but it has shape {tuple(data.shape)}'
            )
            raise ValueError(message)
        return self._get_subject_from_tensor(data)

    def _get_subject_from_tensor(self, tensor: torch.Tensor) -> Subject:
        """Wrap a tensor in a :py:class:`ScalarImage` inside a subject."""
        image = ScalarImage(tensor=tensor)
        return self._get_subject_from_image(image)

    def _get_subject_from_image(self, image: Image) -> Subject:
        """Return a subject holding ``image`` under the default image name."""
        subject = Subject({self.default_image_name: image})
        return subject

    @staticmethod
    def _get_subject_from_dict(
            data: dict,
            image_keys: List[str],
            ) -> Subject:
        """Build a subject, wrapping values under ``image_keys`` as images."""
        subject_dict = {}
        for key, value in data.items():
            if key in image_keys:
                value = ScalarImage(tensor=value)
            subject_dict[key] = value
        return Subject(subject_dict)

    def _get_subject_from_sitk_image(self, image):
        """Convert a SimpleITK image to a subject with one scalar image."""
        tensor, affine = sitk_to_nib(image)
        image = ScalarImage(tensor=tensor, affine=affine)
        return self._get_subject_from_image(image)
401