Passed
Pull Request — master (#371)
by Fernando
01:11
created

Transform.parse_interpolation()   A

Complexity

Conditions 4

Size

Total Lines 16
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 14
nop 1
dl 0
loc 16
rs 9.7
c 0
b 0
f 0
1
import copy
2
import numbers
3
from abc import ABC, abstractmethod
4
from contextlib import contextmanager
5
from typing import Optional, Union, Tuple, Sequence
6
7
import torch
8
import numpy as np
9
import SimpleITK as sitk
10
11
from ..data.subject import Subject
12
from .. import TypeData, TypeNumber
13
from ..utils import nib_to_sitk, sitk_to_nib, to_tuple
14
from .interpolation import Interpolation, get_sitk_interpolator
15
from .data_parser import DataParser, TypeTransformInput
16
17
18
class Transform(ABC):
    """Abstract class for all TorchIO transforms.

    All subclasses should overwrite
    :meth:`torchio.transforms.Transform.apply_transform`,
    which takes data, applies some transformation and returns the result.

    The input can be an instance of
    :class:`torchio.Subject`,
    :class:`torchio.Image`,
    :class:`numpy.ndarray`,
    :class:`torch.Tensor`,
    :class:`SimpleITK.Image`,
    or :class:`dict`.

    Args:
        p: Probability that this transform will be applied.
        copy: Make a shallow copy of the input before applying the transform.
        keys: Mandatory if the input is a :class:`dict`. The transform will
            be applied only to the data in each key.
    """
    def __init__(
            self,
            p: float = 1,
            copy: bool = True,
            keys: Optional[Sequence[str]] = None,
            ):
        # Validate eagerly so an invalid probability fails at construction
        # time instead of on the first call.
        self.probability = self.parse_probability(p)
        self.copy = copy
        self.keys = keys

    def __call__(
            self,
            data: TypeTransformInput,
            ) -> TypeTransformInput:
        """Transform data and return a result of the same type.

        Args:
            data: Instance of 1) :class:`~torchio.Subject`, 4D
                :class:`torch.Tensor` or :class:`numpy.ndarray` with dimensions
                :math:`(C, W, H, D)`, where :math:`C` is the number of channels
                and :math:`W, H, D` are the spatial dimensions. If the input is
                a tensor, the affine matrix will be set to identity. Other
                valid input types are a SimpleITK image, a
                :class:`torchio.Image`, a NiBabel Nifti1 image or a
                :class:`dict`. The output type is the same as the input type.

        Returns:
            The transformed data, or ``data`` itself (untouched) when the
            random draw decides the transform is not applied.
        """
        # Skip the transform with probability 1 - p.
        if torch.rand(1).item() > self.probability:
            return data
        data_parser = DataParser(data, keys=self.keys)
        subject = data_parser.get_subject()
        if self.copy:
            # Shallow copy only, as documented for the ``copy`` argument.
            subject = copy.copy(subject)
        # Make NumPy floating-point errors raise instead of warning while
        # the transform runs.
        with np.errstate(all='raise'):
            transformed = self.apply_transform(subject)
        self.add_transform_to_subject_history(transformed)
        # Every image (intensity or not) produced by a transform must be a
        # 4D float32 tensor.
        for image in transformed.get_images(intensity_only=False):
            ndim = image.data.ndim
            assert ndim == 4, f'Output of {self.name} is {ndim}D'
            dtype = image.data.dtype
            assert dtype is torch.float32, f'Output of {self.name} is {dtype}'

        output = data_parser.get_output(transformed)
        return output

    def __repr__(self):
        """Return ``Name(arg=value, ...)`` built from ``args_names``.

        Falls back to :class:`object`'s representation when the instance
        has no ``args_names`` attribute.
        """
        if hasattr(self, 'args_names'):
            names = self.args_names
            args_strings = [f'{arg}={getattr(self, arg)}' for arg in names]
            if hasattr(self, 'invert_transform') and self.invert_transform:
                args_strings.append('invert=True')
            args_string = ', '.join(args_strings)
            return f'{self.name}({args_string})'
        else:
            return super().__repr__()

    @property
    def name(self):
        """Name of the concrete transform class."""
        return self.__class__.__name__

    @abstractmethod
    def apply_transform(self, subject: Subject):
        """Apply the transform to ``subject`` and return the result."""
        raise NotImplementedError

    def add_transform_to_subject_history(self, subject):
        """Record this transform and its arguments in ``subject``'s history.

        Instances of the classes listed in ``call_others`` are not recorded;
        judging by the name, they delegate to other transforms, which
        presumably record themselves (TODO confirm).
        """
        # Imported locally rather than at module level, presumably to avoid
        # a circular import with the modules defining these subclasses.
        from .augmentation import RandomTransform
        from . import Compose, OneOf, CropOrPad
        call_others = (
            RandomTransform,
            Compose,
            OneOf,
            CropOrPad,
        )
        if not isinstance(self, call_others):
            subject.add_transform(self, self._get_reproducing_arguments())

    @staticmethod
    def to_range(n, around):
        """Return ``(0, n)`` if ``around`` is None, else ``(around - n, around + n)``."""
        if around is None:
            return 0, n
        else:
            return around - n, around + n

    def parse_params(self, params, around, name, make_ranges=True, **kwargs):
        """Expand ``params`` into per-axis values, optionally as ranges.

        Args:
            params: A number or a sequence, converted with ``to_tuple``.
            around: Center used by :meth:`to_range` when building ranges from
                single values; ``None`` makes those ranges start at 0.
            name: Parameter name used in error messages.
            make_ranges: If ``True``, the result must consist of six numbers
                ``(min_1, max_1, min_2, max_2, min_3, max_3)`` and each pair is
                validated with :meth:`_parse_range`.
            **kwargs: Constraints forwarded to :meth:`_parse_range`.

        Returns:
            A tuple with the expanded parameters.

        Raises:
            ValueError: If ``make_ranges`` is ``True`` and the expanded
                parameters do not have length 6, or if a range is invalid.
        """
        params = to_tuple(params)
        if len(params) == 1 or (len(params) == 2 and make_ranges):  # d or (a, b)
            params *= 3  # (d, d, d) or (a, b, a, b, a, b)
        if len(params) == 3 and make_ranges:  # (a, b, c)
            items = [self.to_range(n, around) for n in params]
            # around=None -> (0, a, 0, b, 0, c)
            # around=0    -> (-a, a, -b, b, -c, c)
            # around=1    -> (1-a, 1+a, 1-b, 1+b, 1-c, 1+c)
            params = [n for prange in items for n in prange]
        if make_ranges:
            if len(params) != 6:
                message = (
                    f'If "{name}" is a sequence, it must have length 2, 3 or 6,'
                    f' not {len(params)}'
                )
                raise ValueError(message)
            for param_range in zip(params[::2], params[1::2]):
                self._parse_range(param_range, name, **kwargs)
        return tuple(params)

    @staticmethod
    def _parse_range(
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
            name: str,
            min_constraint: Optional[TypeNumber] = None,
            max_constraint: Optional[TypeNumber] = None,
            type_constraint: Optional[type] = None,
            ) -> Tuple[TypeNumber, TypeNumber]:
        r"""Adapted from ``torchvision.transforms.RandomRotation``.

        Args:
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
                where :math:`n_{min} \leq n_{max}`.
                If a single positive number :math:`n` is provided,
                :math:`n_{min} = -n` and :math:`n_{max} = n`
                (or :math:`n_{min} = n` if :attr:`min_constraint` is given).
            name: Name of the parameter, so that an informative error message
                can be printed.
            min_constraint: Minimal value that :math:`n_{min}` can take,
                default is None, i.e. there is no minimal value.
            max_constraint: Maximal value that :math:`n_{max}` can take,
                default is None, i.e. there is no maximal value.
            type_constraint: Precise type that :math:`n_{max}` and
                :math:`n_{min}` must take.

        Returns:
            A tuple of two numbers :math:`(n_{min}, n_{max})`.

        Raises:
            ValueError: if :attr:`nums_range` is negative
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
            ValueError: if :math:`n_{max} \lt n_{min}`
            ValueError: if :attr:`min_constraint` is not None and
                :math:`n_{min}` is smaller than :attr:`min_constraint`
            ValueError: if :attr:`max_constraint` is not None and
                :math:`n_{max}` is greater than :attr:`max_constraint`
            ValueError: if :attr:`type_constraint` is not None and
                :math:`n_{min}` and :math:`n_{max}` are not of type
                :attr:`type_constraint`.
        """
        if isinstance(nums_range, numbers.Number):  # single number given
            if nums_range < 0:
                raise ValueError(
                    f'If {name} is a single number,'
                    f' it must be positive, not {nums_range}')
            if min_constraint is not None and nums_range < min_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be greater'
                    f' than {min_constraint}, not {nums_range}'
                )
            if max_constraint is not None and nums_range > max_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be smaller'
                    f' than {max_constraint}, not {nums_range}'
                )
            if type_constraint is not None:
                if not isinstance(nums_range, type_constraint):
                    raise ValueError(
                        f'If {name} is a single number, it must be of'
                        f' type {type_constraint}, not {nums_range}'
                    )
            # With a lower bound, a symmetric range around 0 could violate
            # it, so the degenerate range (n, n) is returned instead.
            min_range = -nums_range if min_constraint is None else nums_range
            return (min_range, nums_range)

        try:
            min_value, max_value = nums_range
        except (TypeError, ValueError):
            raise ValueError(
                f'If {name} is not a single number, it must be'
                f' a sequence of len 2, not {nums_range}'
            )

        min_is_number = isinstance(min_value, numbers.Number)
        max_is_number = isinstance(max_value, numbers.Number)
        if not min_is_number or not max_is_number:
            message = (
                f'{name} values must be numbers, not {nums_range}')
            raise ValueError(message)

        if min_value > max_value:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, but it is {nums_range}')

        if min_constraint is not None and min_value < min_constraint:
            raise ValueError(
                f'If {name} is a sequence, the first value must be greater'
                f' than {min_constraint}, but it is {min_value}'
            )

        if max_constraint is not None and max_value > max_constraint:
            raise ValueError(
                f'If {name} is a sequence, the second value must be smaller'
                f' than {max_constraint}, but it is {max_value}'
            )

        if type_constraint is not None:
            min_type_ok = isinstance(min_value, type_constraint)
            max_type_ok = isinstance(max_value, type_constraint)
            if not min_type_ok or not max_type_ok:
                # Report the types of the values, not of the container.
                raise ValueError(
                    f'If "{name}" is a sequence, its values must be of'
                    f' type "{type_constraint}", not "{type(min_value)}"'
                    f' and "{type(max_value)}"'
                )
        return nums_range

    @staticmethod
    def parse_interpolation(interpolation: str) -> str:
        """Validate an interpolation name and return it in lowercase.

        Args:
            interpolation: Case-insensitive name of a member of
                :class:`Interpolation`.

        Returns:
            The lowercase interpolation name.

        Raises:
            TypeError: If ``interpolation`` is not a string.
            ValueError: If the name does not match any :class:`Interpolation`
                member.
        """
        if not isinstance(interpolation, str):
            itype = type(interpolation)
            raise TypeError(f'Interpolation must be a string, not {itype}')
        # Past the guard above the value is known to be a string, so a
        # single lowercase conversion and membership test are enough.
        interpolation = interpolation.lower()
        supported_values = [key.name.lower() for key in Interpolation]
        if interpolation in supported_values:
            return interpolation
        message = (
            f'Interpolation "{interpolation}" must be one of the'
            f' supported values: {supported_values}'
        )
        raise ValueError(message)

    @staticmethod
    def parse_probability(probability: float) -> float:
        """Return ``probability`` unchanged if it is a number in [0, 1].

        Raises:
            ValueError: If ``probability`` is not a number or is out of range.
        """
        is_number = isinstance(probability, numbers.Number)
        if not (is_number and 0 <= probability <= 1):
            message = (
                'Probability must be a number in [0, 1],'
                f' not {probability}'
            )
            raise ValueError(message)
        return probability

    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        """Delegate to the module-level ``nib_to_sitk`` helper."""
        return nib_to_sitk(data, affine)

    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        """Delegate to the module-level ``sitk_to_nib`` helper."""
        return sitk_to_nib(image)

    def _get_reproducing_arguments(self):
        """
        Return a dictionary with the arguments that would be necessary to
        reproduce the transform exactly.
        """
        return {name: getattr(self, name) for name in self.args_names}

    def is_invertible(self):
        """Return whether the transform has an ``invert_transform`` flag."""
        return hasattr(self, 'invert_transform')

    def inverse(self):
        """Return a deep copy of this transform with ``invert_transform`` flipped.

        Raises:
            RuntimeError: If the transform is not invertible.
        """
        if not self.is_invertible():
            raise RuntimeError(f'{self.name} is not invertible')
        new = copy.deepcopy(self)
        new.invert_transform = not self.invert_transform
        return new

    @staticmethod
    @contextmanager
    def _use_seed(seed):
        """Perform an operation using a specific seed for the PyTorch RNG.

        The previous global RNG state is restored on exit, including when
        the body of the ``with`` statement raises.
        """
        torch_rng_state = torch.random.get_rng_state()
        torch.manual_seed(seed)
        try:
            yield
        finally:
            torch.random.set_rng_state(torch_rng_state)

    @staticmethod
    def get_sitk_interpolator(interpolation: str) -> int:
        """Delegate to the module-level ``get_sitk_interpolator`` helper."""
        return get_sitk_interpolator(interpolation)
310