Passed
Pull Request — master (#403)
by Fernando
01:27
created

torchio.transforms.transform   F

Complexity

Total Complexity 92

Size/Duplication

Total Lines 475
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 312
dl 0
loc 475
rs 2
c 0
b 0
f 0
wmc 92

25 Methods

Rating   Name   Duplication   Size   Complexity  
A Transform.to_range() 0 6 2
A Transform.add_transform_to_subject_history() 0 14 2
A Transform.parse_interpolation() 0 16 4
A Transform.__init__() 0 19 2
C Transform.parse_params() 0 19 9
A Transform.__repr__() 0 10 4
A Transform.__call__() 0 31 5
F Transform._parse_range() 0 104 21
A Transform.name() 0 3 1
A Transform.apply_transform() 0 3 1
A Transform._use_seed() 0 8 1
A Transform.ones() 0 3 1
A Transform.get_sitk_interpolator() 0 3 1
A Transform.sitk_to_nib() 0 3 1
A Transform._get_reproducing_arguments() 0 13 1
A Transform.mean() 0 4 1
A Transform.parse_probability() 0 12 3
B Transform.parse_bounds() 0 28 8
A Transform.parse_include_and_exclude() 0 8 3
A Transform.nib_to_sitk() 0 3 1
A Transform.get_mask_from_bounds() 0 13 1
B Transform.get_mask_from_anatomical_label() 0 27 8
A Transform.inverse() 0 6 2
A Transform.is_invertible() 0 2 1
B Transform.get_mask_from_masking_method() 0 27 8

How to fix: Complexity

Complex classes like the one in torchio.transforms.transform often do many different things. To break such a class down, first identify a cohesive component within it. A common way to find such a component is to look for fields and methods that share the same prefix or suffix.
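As a rough illustration of the prefix heuristic, the sketch below groups the method names from the table above by their first underscore-separated word. The helper `group_by_prefix` and its `min_size` parameter are hypothetical and exist only for this example; they are not part of TorchIO or of the inspection tool.

```python
from collections import defaultdict

# Method names copied from the "25 Methods" table of this report.
METHOD_NAMES = [
    'to_range', 'add_transform_to_subject_history', 'parse_interpolation',
    '__init__', 'parse_params', '__repr__', '__call__', '_parse_range',
    'name', 'apply_transform', '_use_seed', 'ones', 'get_sitk_interpolator',
    'sitk_to_nib', '_get_reproducing_arguments', 'mean', 'parse_probability',
    'parse_bounds', 'parse_include_and_exclude', 'nib_to_sitk',
    'get_mask_from_bounds', 'get_mask_from_anatomical_label', 'inverse',
    'is_invertible', 'get_mask_from_masking_method',
]


def group_by_prefix(names, min_size=2):
    """Group names by their first underscore-separated word.

    Leading underscores are ignored, so `_parse_range` lands in the
    same group as `parse_bounds`. Groups smaller than `min_size`
    are dropped.
    """
    groups = defaultdict(list)
    for name in names:
        prefix = name.lstrip('_').split('_')[0]
        groups[prefix].append(name)
    return {
        prefix: members
        for prefix, members in groups.items()
        if len(members) >= min_size
    }


groups = group_by_prefix(METHOD_NAMES)
for prefix, members in groups.items():
    print(f'{prefix}: {", ".join(members)}')
```

For this class, two groups survive the filter: the `parse` methods and the `get` methods. Three of the `get` methods share the longer prefix `get_mask_from`, which hints at a mask-related component that could be extracted.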

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
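A minimal sketch of Extract Class, assuming the bounds-parsing logic is the component being moved out. The names `BoundsParser` and `SlimTransform` are hypothetical, and the sketch accepts plain Python integers only (the listing below also accepts NumPy integers), so treat it as an outline rather than a drop-in replacement.

```python
from typing import Tuple

TypeSixBounds = Tuple[int, int, int, int, int, int]


class BoundsParser:
    """Hypothetical extracted class that owns only the bounds parsing."""

    @staticmethod
    def parse(bounds) -> TypeSixBounds:
        try:
            values = tuple(bounds)
        except TypeError:  # a single integer was given
            values = (bounds,)
        for value in values:
            if not isinstance(value, int) or value < 0:
                raise ValueError(
                    'Bounds values must be integers greater or equal to'
                    f' zero, not "{values}"'
                )
        if len(values) == 6:
            return values
        if len(values) == 1:
            return 6 * values
        if len(values) == 3:
            # (a, b, c) -> (a, a, b, b, c, c)
            return tuple(v for v in values for _ in range(2))
        raise ValueError(
            'Bounds parameter must be an integer or a tuple of'
            f' 3 or 6 integers, not {values}'
        )


class SlimTransform:
    """Stand-in for the original class; it delegates to the new class."""

    # Keeping the old name as a thin alias lets existing callers work.
    parse_bounds = staticmethod(BoundsParser.parse)


print(SlimTransform.parse_bounds((1, 2, 3)))
```

Keeping a delegating alias on the original class means callers of `parse_bounds` do not have to change while the logic moves elsewhere.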

import copy
import numbers
import warnings
from typing import Union, Tuple
from abc import ABC, abstractmethod
from contextlib import contextmanager

import torch
import numpy as np
import SimpleITK as sitk

from ..utils import to_tuple
from ..data.subject import Subject
from ..data.io import nib_to_sitk, sitk_to_nib
from ..data.image import LabelMap
from ..typing import (
    TypeKeys,
    TypeData,
    TypeNumber,
    TypeCallable,
    TypeTripletInt,
)
from .interpolation import Interpolation, get_sitk_interpolator
from .data_parser import DataParser, TypeTransformInput

TypeSixBounds = Tuple[int, int, int, int, int, int]
TypeBounds = Union[
    int,
    TypeTripletInt,
    TypeSixBounds,
]
TypeMaskingMethod = Union[str, TypeCallable, TypeBounds, None]
anat_axes = 'Left', 'Right', 'Anterior', 'Posterior', 'Inferior', 'Superior'


class Transform(ABC):
    """Abstract class for all TorchIO transforms.

    All subclasses must override
    :meth:`Transform.apply_transform`,
    which takes data, applies some transformation and returns the result.

    The input can be an instance of
    :class:`torchio.Subject`,
    :class:`torchio.Image`,
    :class:`numpy.ndarray`,
    :class:`torch.Tensor`,
    :class:`SimpleITK.Image`,
    or :class:`dict`.

    Args:
        p: Probability that this transform will be applied.
        copy: Make a shallow copy of the input before applying the transform.
        include: Sequence of strings with the names of the only images to which
            the transform will be applied.
            Mandatory if the input is a :class:`dict`.
        exclude: Sequence of strings with the names of the images to which
            the transform will not be applied, apart from the ones that are
            excluded because of the transform type.
            For example, if a subject includes an MRI, a CT and a label map,
            and the CT is added to the list of exclusions of an intensity
            transform such as :class:`~torchio.transforms.RandomBlur`,
            the transform will be only applied to the MRI, as the label map is
            excluded by default by intensity transforms.
    """
    def __init__(
            self,
            p: float = 1,
            copy: bool = True,
            include: TypeKeys = None,
            exclude: TypeKeys = None,
            keys: TypeKeys = None,
            ):
        self.probability = self.parse_probability(p)
        self.copy = copy
        if keys is not None:
            message = (
                'The "keys" argument is deprecated and will be removed in the'
                ' future. Use "include" instead.'
            )
            warnings.warn(message, DeprecationWarning)
            include = keys
        self.include, self.exclude = self.parse_include_and_exclude(
            include, exclude)

    def __call__(
            self,
            data: TypeTransformInput,
            ) -> TypeTransformInput:
        """Transform data and return a result of the same type.

        Args:
            data: Instance of :class:`~torchio.Subject`, 4D
                :class:`torch.Tensor` or :class:`numpy.ndarray` with dimensions
                :math:`(C, W, H, D)`, where :math:`C` is the number of channels
                and :math:`W, H, D` are the spatial dimensions. If the input is
                a tensor, the affine matrix will be set to identity. Other
                valid input types are a SimpleITK image, a
                :class:`torchio.Image`, a NiBabel Nifti1 image or a
                :class:`dict`. The output type is the same as the input type.
        """
        if torch.rand(1).item() > self.probability:
            return data
        data_parser = DataParser(data, keys=self.include)
        subject = data_parser.get_subject()
        if self.copy:
            subject = copy.copy(subject)
        with np.errstate(all='raise'):
            transformed = self.apply_transform(subject)
        self.add_transform_to_subject_history(transformed)
        for image in transformed.get_images(intensity_only=False):
            ndim = image.data.ndim
            assert ndim == 4, f'Output of {self.name} is {ndim}D'

        output = data_parser.get_output(transformed)
        return output

    def __repr__(self):
        if hasattr(self, 'args_names'):
            names = self.args_names
            args_strings = [f'{arg}={getattr(self, arg)}' for arg in names]
            if hasattr(self, 'invert_transform') and self.invert_transform:
                args_strings.append('invert=True')
            args_string = ', '.join(args_strings)
            return f'{self.name}({args_string})'
        else:
            return super().__repr__()

    @property
    def name(self):
        return self.__class__.__name__

    @abstractmethod
    def apply_transform(self, subject: Subject):
        raise NotImplementedError

    def add_transform_to_subject_history(self, subject):
        from .augmentation import RandomTransform
        from . import Compose, OneOf, CropOrPad, EnsureShapeMultiple
        from .preprocessing.label import SequentialLabels
        call_others = (
            RandomTransform,
            Compose,
            OneOf,
            CropOrPad,
            EnsureShapeMultiple,
            SequentialLabels,
        )
        if not isinstance(self, call_others):
            subject.add_transform(self, self._get_reproducing_arguments())

    @staticmethod
    def to_range(n, around):
        if around is None:
            return 0, n
        else:
            return around - n, around + n

    def parse_params(self, params, around, name, make_ranges=True, **kwargs):
        params = to_tuple(params)
        # d or (a, b)
        if len(params) == 1 or (len(params) == 2 and make_ranges):
            params *= 3  # (d, d, d) or (a, b, a, b, a, b)
        if len(params) == 3 and make_ranges:  # (a, b, c)
            items = [self.to_range(n, around) for n in params]
            # (-a, a, -b, b, -c, c) or (1-a, 1+a, 1-b, 1+b, 1-c, 1+c)
            params = [n for prange in items for n in prange]
        if make_ranges:
            if len(params) != 6:
                message = (
                    f'If "{name}" is a sequence, it must have length 2, 3 or'
                    f' 6, not {len(params)}'
                )
                raise ValueError(message)
            for param_range in zip(params[::2], params[1::2]):
                self._parse_range(param_range, name, **kwargs)
        return tuple(params)

    @staticmethod
    def _parse_range(
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
            name: str,
            min_constraint: TypeNumber = None,
            max_constraint: TypeNumber = None,
            type_constraint: type = None,
            ) -> Tuple[TypeNumber, TypeNumber]:
        r"""Adapted from :class:`torchvision.transforms.RandomRotation`.

        Args:
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
                where :math:`n_{min} \leq n_{max}`.
                If a single positive number :math:`n` is provided,
                :math:`n_{min} = -n` and :math:`n_{max} = n`.
            name: Name of the parameter, so that an informative error message
                can be printed.
            min_constraint: Minimal value that :math:`n_{min}` can take,
                default is None, i.e. there is no minimal value.
            max_constraint: Maximal value that :math:`n_{max}` can take,
                default is None, i.e. there is no maximal value.
            type_constraint: Precise type that :math:`n_{max}` and
                :math:`n_{min}` must take.

        Returns:
            A tuple of two numbers :math:`(n_{min}, n_{max})`.

        Raises:
            ValueError: if :attr:`nums_range` is negative
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
            ValueError: if :math:`n_{max} \lt n_{min}`
            ValueError: if :attr:`min_constraint` is not None and
                :math:`n_{min}` is smaller than :attr:`min_constraint`
            ValueError: if :attr:`max_constraint` is not None and
                :math:`n_{max}` is greater than :attr:`max_constraint`
            ValueError: if :attr:`type_constraint` is not None and
                :math:`n_{min}` and :math:`n_{max}` are not of type
                :attr:`type_constraint`.
        """
        if isinstance(nums_range, numbers.Number):  # single number given
            if nums_range < 0:
                raise ValueError(
                    f'If {name} is a single number,'
                    f' it must be positive, not {nums_range}')
            if min_constraint is not None and nums_range < min_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be greater'
                    f' than {min_constraint}, not {nums_range}'
                )
            if max_constraint is not None and nums_range > max_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be smaller'
                    f' than {max_constraint}, not {nums_range}'
                )
            if type_constraint is not None:
                if not isinstance(nums_range, type_constraint):
                    raise ValueError(
                        f'If {name} is a single number, it must be of'
                        f' type {type_constraint}, not {nums_range}'
                    )
            min_range = -nums_range if min_constraint is None else nums_range
            return (min_range, nums_range)

        try:
            min_value, max_value = nums_range
        except (TypeError, ValueError):
            raise ValueError(
                f'If {name} is not a single number, it must be'
                f' a sequence of len 2, not {nums_range}'
            )

        min_is_number = isinstance(min_value, numbers.Number)
        max_is_number = isinstance(max_value, numbers.Number)
        if not min_is_number or not max_is_number:
            message = (
                f'{name} values must be numbers, not {nums_range}')
            raise ValueError(message)

        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, but it is {nums_range}')

        if min_constraint is not None and min_value < min_constraint:
            raise ValueError(
                f'If {name} is a sequence, the first value must be greater'
                f' than {min_constraint}, but it is {min_value}'
            )

        if max_constraint is not None and max_value > max_constraint:
            raise ValueError(
                f'If {name} is a sequence, the second value must be smaller'
                f' than {max_constraint}, but it is {max_value}'
            )

        if type_constraint is not None:
            min_type_ok = isinstance(min_value, type_constraint)
            max_type_ok = isinstance(max_value, type_constraint)
            if not min_type_ok or not max_type_ok:
                raise ValueError(
                    f'If "{name}" is a sequence, its values must be of'
                    f' type "{type_constraint}", not "{type(nums_range)}"'
                )
        return nums_range

    @staticmethod
    def parse_interpolation(interpolation: str) -> str:
        if not isinstance(interpolation, str):
            itype = type(interpolation)
            raise TypeError(f'Interpolation must be a string, not {itype}')
        # The input is known to be a string from here on
        interpolation = interpolation.lower()
        supported_values = [key.name.lower() for key in Interpolation]
        if interpolation in supported_values:
            return interpolation
        message = (
            f'Interpolation "{interpolation}" of type {type(interpolation)}'
            f' must be a string among the supported values: {supported_values}'
        )
        raise ValueError(message)

    @staticmethod
    def parse_probability(probability: float) -> float:
        is_number = isinstance(probability, numbers.Number)
        if not (is_number and 0 <= probability <= 1):
            message = (
                'Probability must be a number in [0, 1],'
                f' not {probability}'
            )
            raise ValueError(message)
        return probability

    @staticmethod
    def parse_include_and_exclude(
            include: TypeKeys = None,
            exclude: TypeKeys = None,
            ) -> Tuple[TypeKeys, TypeKeys]:
        if include is not None and exclude is not None:
            raise ValueError('Include and exclude cannot both be specified')
        return include, exclude

    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        return nib_to_sitk(data, affine)

    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        return sitk_to_nib(image)

    def _get_reproducing_arguments(self):
        """
        Return a dictionary with the arguments that would be necessary to
        reproduce the transform exactly.
        """
        reproducing_arguments = {
            'include': self.include,
            'exclude': self.exclude,
            'copy': self.copy,
        }
        args_names = {name: getattr(self, name) for name in self.args_names}
        reproducing_arguments.update(args_names)
        return reproducing_arguments

    def is_invertible(self):
        return hasattr(self, 'invert_transform')

    def inverse(self):
        if not self.is_invertible():
            raise RuntimeError(f'{self.name} is not invertible')
        new = copy.deepcopy(self)
        new.invert_transform = not self.invert_transform
        return new

    @staticmethod
    @contextmanager
    def _use_seed(seed):
        """Perform an operation using a specific seed for the PyTorch RNG"""
        torch_rng_state = torch.random.get_rng_state()
        torch.manual_seed(seed)
        try:
            yield
        finally:
            # Restore the previous RNG state even if the managed block raises
            torch.random.set_rng_state(torch_rng_state)

    @staticmethod
    def get_sitk_interpolator(interpolation: str) -> int:
        return get_sitk_interpolator(interpolation)

    @staticmethod
    def parse_bounds(bounds_parameters: TypeBounds) -> TypeSixBounds:
        try:
            bounds_parameters = tuple(bounds_parameters)
        except TypeError:
            bounds_parameters = (bounds_parameters,)

        # Check that numbers are integers
        for number in bounds_parameters:
            if not isinstance(number, (int, np.integer)) or number < 0:
                message = (
                    'Bounds values must be integers greater or equal to zero,'
                    f' not "{bounds_parameters}" of type {type(number)}'
                )
                raise ValueError(message)
        bounds_parameters = tuple(int(n) for n in bounds_parameters)
        bounds_parameters_length = len(bounds_parameters)
        if bounds_parameters_length == 6:
            return bounds_parameters
        if bounds_parameters_length == 1:
            return 6 * bounds_parameters
        if bounds_parameters_length == 3:
            return tuple(np.repeat(bounds_parameters, 2).tolist())
        message = (
            'Bounds parameter must be an integer or a tuple of'
            f' 3 or 6 integers, not {bounds_parameters}'
        )
        raise ValueError(message)

    @staticmethod
    def ones(tensor: torch.Tensor) -> torch.Tensor:
        return torch.ones_like(tensor, dtype=torch.bool)

    @staticmethod
    def mean(tensor: torch.Tensor) -> torch.Tensor:
        mask = tensor > tensor.mean()
        return mask

    @staticmethod
    def get_mask_from_masking_method(
            masking_method: TypeMaskingMethod,
            subject: Subject,
            tensor: torch.Tensor,
            ) -> torch.Tensor:
        if masking_method is None:
            return Transform.ones(tensor)
        elif callable(masking_method):
            return masking_method(tensor)
        elif type(masking_method) is str:
            in_subject = masking_method in subject
            if in_subject and isinstance(subject[masking_method], LabelMap):
                return subject[masking_method].data.bool()
            masking_method = masking_method.capitalize()
            if masking_method in anat_axes:
                return Transform.get_mask_from_anatomical_label(
                    masking_method, tensor)
        elif type(masking_method) in (tuple, list, int):
            return Transform.get_mask_from_bounds(masking_method, tensor)
        message = (
            'Masking method parameter must be a function, a label map name,'
            f' an anatomical label: {anat_axes}, or a bounds parameter'
            ' (an int, tuple of 3 ints, or tuple of 6 ints),'
            f' not {masking_method} of type {type(masking_method)}'
        )
        raise ValueError(message)

    @staticmethod
    def get_mask_from_anatomical_label(
            anatomical_label: str,
            tensor: torch.Tensor,
            ) -> torch.Tensor:
        anatomical_label = anatomical_label.title()
        if anatomical_label not in anat_axes:
            message = (
                f'Anatomical label must be one of {anat_axes}'
                f' not {anatomical_label}'
            )
            raise ValueError(message)
        mask = torch.zeros_like(tensor, dtype=torch.bool)
        _, width, height, depth = tensor.shape
        if anatomical_label == 'Right':
            mask[:, width // 2:] = True
        elif anatomical_label == 'Left':
            mask[:, :width // 2] = True
        elif anatomical_label == 'Anterior':
            mask[:, :, height // 2:] = True
        elif anatomical_label == 'Posterior':
            mask[:, :, :height // 2] = True
        elif anatomical_label == 'Superior':
            mask[:, :, :, depth // 2:] = True
        elif anatomical_label == 'Inferior':
            mask[:, :, :, :depth // 2] = True
        return mask

    @staticmethod
    def get_mask_from_bounds(
            bounds_parameters: TypeBounds,
            tensor: torch.Tensor,
            ) -> torch.Tensor:
        bounds_parameters = Transform.parse_bounds(bounds_parameters)
        low = bounds_parameters[::2]
        high = bounds_parameters[1::2]
        i0, j0, k0 = low
        i1, j1, k1 = np.array(tensor.shape[1:]) - high
        mask = torch.zeros_like(tensor, dtype=torch.bool)
        mask[:, i0:i1, j0:j1, k0:k1] = True
        return mask
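
The table above rates Transform._parse_range() F, with a complexity of 21. One possible way to lower that number, sketched here under assumptions, is to split the single-number and two-value branches into separate functions that share one constraint check. The function names `check_constraints`, `parse_single`, `parse_pair` and `parse_range` are hypothetical, the error messages are simplified, and the pair branch returns a tuple instead of the input object, so this is an outline and not the project's implementation.

```python
import numbers


def check_constraints(value, name, min_constraint=None,
                      max_constraint=None, type_constraint=None):
    """Validate one number against the optional constraints."""
    if min_constraint is not None and value < min_constraint:
        raise ValueError(
            f'{name} must not be smaller than {min_constraint}, got {value}')
    if max_constraint is not None and value > max_constraint:
        raise ValueError(
            f'{name} must not be greater than {max_constraint}, got {value}')
    if type_constraint is not None and not isinstance(value, type_constraint):
        raise ValueError(
            f'{name} must be of type {type_constraint}, got {value!r}')


def parse_single(value, name, **constraints):
    """Handle the case where a single number is given."""
    if value < 0:
        raise ValueError(f'{name} must not be negative, got {value}')
    check_constraints(value, name, **constraints)
    # As in the listing: symmetric range unless a minimum is imposed
    low = -value if constraints.get('min_constraint') is None else value
    return low, value


def parse_pair(pair, name, **constraints):
    """Handle the case where a sequence of two numbers is given."""
    try:
        low, high = pair
    except (TypeError, ValueError):
        raise ValueError(
            f'{name} must be a number or a sequence of length 2,'
            f' got {pair}') from None
    if not all(isinstance(n, numbers.Number) for n in (low, high)):
        raise ValueError(f'{name} values must be numbers, got {pair}')
    if low > high:
        raise ValueError(
            f'The second value of {name} must be equal or greater than'
            f' the first, got {pair}')
    # Because low <= high, checking both values against both limits
    # rejects exactly the same inputs as checking low against the
    # minimum and high against the maximum.
    for value in (low, high):
        check_constraints(value, name, **constraints)
    return low, high


def parse_range(nums_range, name, **constraints):
    """Dispatch to the single-number or the two-value parser."""
    if isinstance(nums_range, numbers.Number):
        return parse_single(nums_range, name, **constraints)
    return parse_pair(nums_range, name, **constraints)


print(parse_range(3, 'degrees'))
print(parse_range((1, 2), 'scales', min_constraint=0))
```

Each function now has only a few branches, so the complexity is spread over four small units instead of being concentrated in one method.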