Passed
Push — master ( 357cca...19ff6e )
by Fernando
01:52
created

Transform.parse_range()   F

Complexity

Conditions 21

Size

Total Lines 102
Code Lines 48

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 21
eloc 48
nop 5
dl 0
loc 102
rs 0
c 0
b 0
f 0

How to fix

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

A commonly applied refactoring for this is Extract Method.
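As a rough illustration of Extract Method applied to the flagged parse_range(), the sketch below normalizes a single number into a pair first and then runs each check once through a small, named helper. The helper names (_as_pair, _check_numbers, _check_order, _check_bounds, _check_type) are hypothetical and not part of TorchIO. The error messages are simplified and sequences come back as tuples, so treat it as a sketch under those assumptions rather than a drop-in replacement.

```python
import numbers


def _as_pair(nums_range, name, min_constraint):
    """Normalize a single number or a length-2 sequence to (low, high)."""
    if isinstance(nums_range, numbers.Number):
        if nums_range < 0:
            raise ValueError(
                f'If {name} is a single number,'
                f' it must be positive, not {nums_range}')
        # Same rule as the listing: mirror around zero unless a
        # minimum constraint is given.
        low = -nums_range if min_constraint is None else nums_range
        return low, nums_range
    try:
        low, high = nums_range
    except (TypeError, ValueError):
        raise ValueError(
            f'If {name} is not a single number, it must be'
            f' a sequence of len 2, not {nums_range}') from None
    return low, high


def _check_numbers(low, high, name):
    if not all(isinstance(n, numbers.Number) for n in (low, high)):
        raise ValueError(f'{name} values must be numbers, not {(low, high)}')


def _check_order(low, high, name):
    if low > high:
        raise ValueError(
            f'The second value of {name} must be equal or greater'
            f' than the first, not {(low, high)}')


def _check_bounds(low, high, name, min_constraint, max_constraint):
    if min_constraint is not None and low < min_constraint:
        raise ValueError(
            f'{name} must not be smaller than {min_constraint}, got {low}')
    if max_constraint is not None and high > max_constraint:
        raise ValueError(
            f'{name} must not be greater than {max_constraint}, got {high}')


def _check_type(low, high, name, type_constraint):
    if type_constraint is None:
        return
    if not all(isinstance(n, type_constraint) for n in (low, high)):
        raise ValueError(
            f'{name} values must be of type {type_constraint},'
            f' not {(low, high)}')


def parse_range(nums_range, name, min_constraint=None,
                max_constraint=None, type_constraint=None):
    """Validate nums_range and return it as a (low, high) tuple."""
    low, high = _as_pair(nums_range, name, min_constraint)
    _check_numbers(low, high, name)
    _check_order(low, high, name)
    _check_bounds(low, high, name, min_constraint, max_constraint)
    _check_type(low, high, name, type_constraint)
    return low, high
```

With this shape, parse_range(10, 'degrees') returns (-10, 10), and each helper carries only a few conditions, which is what brings the per-function condition count down.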

Complexity

Complex classes like torchio.transforms.transform.Transform.parse_range() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

import copy
import numbers
import warnings
from abc import ABC, abstractmethod
from typing import Optional, Union, Tuple, List

import torch
import numpy as np
import SimpleITK as sitk

from .. import TypeData, DATA, TypeNumber
from ..data.subject import Subject
from ..data.image import Image, ScalarImage
from ..data.dataset import ImagesDataset
from ..utils import nib_to_sitk, sitk_to_nib
from .interpolation import Interpolation


IMAGE_NAME = 'image'


class Transform(ABC):
    """Abstract class for all TorchIO transforms.

    All classes used to transform a sample from an
    :py:class:`~torchio.ImagesDataset` should subclass it.
    All subclasses should overwrite
    :py:meth:`torchio.transforms.Transform.apply_transform`,
    which takes a sample, applies some transformation and returns the result.

    Args:
        p: Probability that this transform will be applied.
        copy: Make a deep copy of the input before applying the transform.
        keys: If the input is a dictionary, the corresponding values will be
            converted to :py:class:`torchio.ScalarImage` so that the transform
            is applied to them only.
    """
    # The docstring must be the first statement in the class body,
    # so the class attribute goes after it.
    image_name = 'image'

    def __init__(
            self,
            p: float = 1,
            copy: bool = True,
            keys: Optional[List[str]] = None,
            ):
        self.probability = self.parse_probability(p)
        self.copy = copy
        self.keys = keys

    def __call__(self, data: Union[Subject, torch.Tensor, np.ndarray]):
        """Transform a sample and return the result.

        Args:
            data: Instance of :py:class:`~torchio.Subject`, 4D
                :py:class:`torch.Tensor` or 4D NumPy array with dimensions
                :math:`(C, D, H, W)`, where :math:`C` is the number of channels
                and :math:`D, H, W` are the spatial dimensions. If the input is
                a tensor, the affine matrix is an identity and a tensor will
                also be returned.
        """
        if torch.rand(1).item() > self.probability:
            return data

        is_dict = False
        if isinstance(data, (np.ndarray, torch.Tensor)):
            is_array = isinstance(data, np.ndarray)
            is_tensor = True
            is_image = False
            sample = self.parse_tensor(data)
        elif isinstance(data, Image):
            sample = self._get_subject_from_image(data)
            is_tensor = is_array = False
            is_image = True
        elif isinstance(data, Subject):
            is_tensor = is_array = is_image = False
            sample = data
        elif isinstance(data, dict):  # e.g. Eisen dict
            if self.keys is None:
                message = (
                    'If input is a dictionary, a value for "keys" must be'
                    ' specified when instantiating the transform'
                )
                raise RuntimeError(message)
            sample = self._get_subject_from_dict(data, self.keys)
            is_tensor = is_array = is_image = False
            is_dict = True
        # Reported issue: The variable sample does not seem to be defined for all execution paths.
        self.parse_sample(sample)

        if self.copy:
            sample = copy.copy(sample)

        with np.errstate(all='raise'):
            transformed = self.apply_transform(sample)

        for image in transformed.get_images(intensity_only=False):
            ndim = image[DATA].ndim
            assert ndim == 4, f'Output of {self.name} is {ndim}D'

        # Reported issue: The variable is_tensor does not seem to be defined for all execution paths.
        if is_tensor:
            transformed = transformed[self.image_name][DATA]
        # Reported issue: The variable is_array does not seem to be defined for all execution paths.
        if is_array:
            transformed = transformed.numpy()
        # Reported issue: The variable is_image does not seem to be defined for all execution paths.
        if is_image:
            transformed = transformed[IMAGE_NAME]
        if is_dict:
            transformed = dict(transformed)
            for key, value in transformed.items():
                if isinstance(value, Image):
                    transformed[key] = value.data
        return transformed

    @abstractmethod
    def apply_transform(self, sample: Subject):
        raise NotImplementedError

    @staticmethod
    def parse_range(
            nums_range: Union[TypeNumber, Tuple[TypeNumber, TypeNumber]],
            name: str,
            min_constraint: TypeNumber = None,
            max_constraint: TypeNumber = None,
            type_constraint: type = None,
            ) -> Tuple[TypeNumber, TypeNumber]:
        r"""Adapted from ``torchvision.transforms.RandomRotation``.

        Args:
            nums_range: Tuple of two numbers :math:`(n_{min}, n_{max})`,
                where :math:`n_{min} \leq n_{max}`.
                If a single positive number :math:`n` is provided,
                :math:`n_{min} = -n` and :math:`n_{max} = n`.
            name: Name of the parameter, so that an informative error message
                can be printed.
            min_constraint: Minimal value that :math:`n_{min}` can take,
                default is None, i.e. there is no minimal value.
            max_constraint: Maximal value that :math:`n_{max}` can take,
                default is None, i.e. there is no maximal value.
            type_constraint: Precise type that :math:`n_{max}` and
                :math:`n_{min}` must take.

        Returns:
            A tuple of two numbers :math:`(n_{min}, n_{max})`.

        Raises:
            ValueError: if :attr:`nums_range` is negative
            ValueError: if :math:`n_{max}` or :math:`n_{min}` is not a number
            ValueError: if :math:`n_{max} \lt n_{min}`
            ValueError: if :attr:`min_constraint` is not None and
                :math:`n_{min}` is smaller than :attr:`min_constraint`
            ValueError: if :attr:`max_constraint` is not None and
                :math:`n_{max}` is greater than :attr:`max_constraint`
            ValueError: if :attr:`type_constraint` is not None and
                :math:`n_{min}` and :math:`n_{max}` are not of type
                :attr:`type_constraint`.
        """
        if isinstance(nums_range, numbers.Number):
            if nums_range < 0:
                raise ValueError(
                    f'If {name} is a single number,'
                    f' it must be positive, not {nums_range}')
            if min_constraint is not None and nums_range < min_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be greater'
                    f' than {min_constraint}, not {nums_range}'
                )
            if max_constraint is not None and nums_range > max_constraint:
                raise ValueError(
                    f'If {name} is a single number, it must be smaller'
                    f' than {max_constraint}, not {nums_range}'
                )
            if type_constraint is not None and \
                    not isinstance(nums_range, type_constraint):
                raise ValueError(
                    f'If {name} is a single number, it must be of'
                    f' type {type_constraint}, not {nums_range}'
                )
            min_range = -nums_range if min_constraint is None else nums_range
            return (min_range, nums_range)

        try:
            min_degree, max_degree = nums_range
        except (TypeError, ValueError):
            raise ValueError(
                f'If {name} is not a single number, it must be'
                f' a sequence of len 2, not {nums_range}'
            )

        if not isinstance(min_degree, numbers.Number) or \
                not isinstance(max_degree, numbers.Number):
            message = (
                f'{name} values must be numbers, not {nums_range}')
            raise ValueError(message)

        if min_degree > max_degree:
            raise ValueError(
                f'If {name} is a sequence, the second value must be'
                f' equal or greater than the first, not {nums_range}')

        if min_constraint is not None and min_degree < min_constraint:
            raise ValueError(
                f'If {name} is a sequence, the first value must be greater'
                f' than {min_constraint}, not {min_degree}'
            )

        if max_constraint is not None and max_degree > max_constraint:
            raise ValueError(
                f'If {name} is a sequence, the second value must be smaller'
                f' than {max_constraint}, not {max_degree}'
            )

        if type_constraint is not None:
            if not isinstance(min_degree, type_constraint) or \
                    not isinstance(max_degree, type_constraint):
                raise ValueError(
                    f'If {name} is a sequence, its values must be of'
                    f' type {type_constraint}, not {nums_range}'
                )
        return nums_range

    @staticmethod
    def parse_probability(probability: float) -> float:
        is_number = isinstance(probability, numbers.Number)
        if not (is_number and 0 <= probability <= 1):
            message = (
                'Probability must be a number in [0, 1],'
                f' not {probability}'
            )
            raise ValueError(message)
        return probability

    @staticmethod
    def parse_sample(sample: Subject) -> None:
        if not isinstance(sample, Subject):
            message = (
                'Input to a transform must be a tensor or an instance'
                f' of torchio.Subject, not "{type(sample)}"'
            )
            raise RuntimeError(message)

    def parse_tensor(self, data: TypeData) -> Subject:
        if data.ndim != 4:
            message = (
                'The input must be a 4D tensor with dimensions'
                f' (channels, x, y, z) but it has shape {tuple(data.shape)}'
            )
            raise ValueError(message)
        return self._get_subject_from_tensor(data)

    @staticmethod
    def parse_interpolation(interpolation: str) -> Interpolation:
        if isinstance(interpolation, Interpolation):
            message = (
                'Interpolation of type torchio.Interpolation'
                ' is deprecated, please use a string instead'
            )
            warnings.warn(message, FutureWarning)
        elif isinstance(interpolation, str):
            interpolation = interpolation.lower()
            supported_values = [key.name.lower() for key in Interpolation]
            if interpolation in supported_values:
                interpolation = getattr(Interpolation, interpolation.upper())
            else:
                message = (
                    f'Interpolation "{interpolation}" is not among'
                    f' the supported values: {supported_values}'
                )
                raise AttributeError(message)
        else:
            message = (
                'image_interpolation must be a string,'
                f' not {type(interpolation)}'
            )
            raise TypeError(message)
        return interpolation

    def _get_subject_from_tensor(self, tensor: torch.Tensor) -> Subject:
        image = ScalarImage(tensor=tensor, channels_last=False)
        return Subject({self.image_name: image})

    @staticmethod
    def _get_subject_from_image(image: Image) -> Subject:
        subject = Subject({IMAGE_NAME: image})
        return subject

    @staticmethod
    def _get_subject_from_dict(
            data: dict,
            image_keys: List[str],
            ) -> Subject:
        subject_dict = {}
        for key, value in data.items():
            if key in image_keys:
                value = ScalarImage(tensor=value)
            subject_dict[key] = value
        return Subject(subject_dict)

    @staticmethod
    def nib_to_sitk(data: TypeData, affine: TypeData) -> sitk.Image:
        return nib_to_sitk(data, affine)

    @staticmethod
    def sitk_to_nib(image: sitk.Image) -> Tuple[torch.Tensor, np.ndarray]:
        return sitk_to_nib(image)

    @property
    def name(self):
        return self.__class__.__name__
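
The four "does not seem to be defined for all execution paths" findings in __call__ share one cause: the if/elif chain has no final else, so an unsupported input type falls through with sample and the is_* flags unassigned, and the next statement would fail with an UnboundLocalError before parse_sample() can report a readable message. One possible fix, sketched below, is to extract the dispatch into a helper that either returns everything the caller needs or raises. The names classify_input and InputKind are hypothetical, the Subject and Image classes are stand-ins so the example runs without TorchIO, and the array/tensor branch and the conversion of dictionary values to images are left out for brevity.

```python
from enum import Enum, auto


class Subject(dict):
    """Stand-in for torchio.Subject (assumed here to behave like a dict)."""


class Image:
    """Stand-in for torchio.Image."""


class InputKind(Enum):
    IMAGE = auto()
    SUBJECT = auto()
    DICT = auto()


def classify_input(data, keys=None):
    """Return (kind, subject) for supported inputs, raise otherwise."""
    if isinstance(data, Image):
        return InputKind.IMAGE, Subject({'image': data})
    if isinstance(data, Subject):  # must come before the plain dict check
        return InputKind.SUBJECT, data
    if isinstance(data, dict):
        if keys is None:
            raise RuntimeError(
                'If input is a dictionary, a value for "keys" must be'
                ' specified when instantiating the transform')
        # Simplification: the listing also wraps the values at `keys`
        # in ScalarImage; that step is omitted here.
        return InputKind.DICT, Subject(data)
    # The branch missing from the listing: every path now either
    # returns both values or raises.
    raise TypeError(f'Unsupported input type: {type(data)}')
```

With a single kind value instead of four booleans, the output conversion at the end of __call__ could branch on that value, which also removes the variables the analyzer flags.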