Passed
Pull Request — master (#353)
by Fernando
01:17
created

RandomAffine.__init__()   A

Complexity

Conditions 2

Size

Total Lines 27
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 24
nop 10
dl 0
loc 27
rs 9.304
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent when you need more or different data.

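There are several approaches to avoid long parameter lists, such as bundling related parameters into a single parameter object, passing a whole object instead of several of its values, or replacing a parameter with a call the method can make itself.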

1
import copy
2
from numbers import Number
3
from typing import Tuple, Optional, List, Union
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ....data.subject import Subject
10
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
11
from ....torchio import (
12
    INTENSITY,
13
    DATA,
14
    AFFINE,
15
    TYPE,
16
    TypeRangeFloat,
17
    TypeSextetFloat,
18
    TypeTripletFloat,
19
)
20
from ... import SpatialTransform
21
from .. import Interpolation, get_sitk_interpolator
22
from .. import RandomTransform
23
24
25
# Accepted forms of the range-like arguments of RandomAffine (scales,
# degrees, translation): per its docstring, one, two, three or six values.
TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]
26
27
28
class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`. In that case ``scales`` must
            have one or two values.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If ``center`` is not ``'image'`` or ``'origin'``, if
            ``default_pad_value`` is invalid, or if ``isotropic`` is ``True``
            and ``scales`` has three or six values.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=True,
        ...     image_interpolation='nearest',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "scales=(0.9, 1.2) degrees=10 isotropic=True image_interpolation=nearest" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.isotropic = isotropic
        # Validate before parsing so that per-axis scales combined with
        # isotropic=True are rejected with an explicit message.
        _parse_scales_isotropic(scales, isotropic)
        # The second argument appears to be the centre of the range built
        # from a single value (1 for scales, 0 otherwise), matching the
        # "1 - x, 1 + x" / "-x, x" rules in the docstring — confirm against
        # parse_params.
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample the scaling, rotation and translation parameters.

        Args:
            scales: Sextet of scaling ranges.
            degrees: Sextet of rotation ranges, in degrees.
            translation: Sextet of translation ranges, in mm.
            isotropic: If ``True``, all scaling factors are set to the first
                sampled one.

        Returns:
            The sampled ``(scaling, rotation, translation)`` tensors.
        """
        scaling_params = self.sample_uniform_sextet(scales)
        if isotropic:
            # Same factor along every axis.
            scaling_params.fill_(scaling_params[0])
        rotation_params = self.sample_uniform_sextet(degrees)
        translation_params = self.sample_uniform_sextet(translation)
        return scaling_params, rotation_params, translation_params

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample parameters and apply the resulting :class:`Affine`."""
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        arguments = dict(
            scales=scaling_params.tolist(),
            degrees=rotation_params.tolist(),
            translation=translation_params.tolist(),
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.image_interpolation,
        )
        transform = Affine(**arguments)
        transformed = transform(subject)
        return transformed
161
162
163
class Affine(SpatialTransform):
    r"""Apply affine transformation.

    Args:
        scales: Tuple :math:`(s_1, s_2, s_3)` defining the
            scaling values along each dimension. Values greater than 1 make
            the objects look larger.
        degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
            rotation around each axis.
        translation: Tuple :math:`(t_1, t_2, t_3)` defining the
            translation in mm along each axis.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If ``center`` is not ``'image'`` or ``'origin'`` or if
            ``default_pad_value`` is invalid.
    """
    def __init__(
            self,
            scales: TypeTripletFloat,
            degrees: TypeTripletFloat,
            translation: TypeTripletFloat,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        # make_ranges=False: presumably keeps the values as exact parameters
        # instead of expanding them into sampling ranges — confirm.
        self.scales = self.parse_params(
            scales,
            None,
            'scales',
            make_ranges=False,
            min_constraint=0,
        )
        self.degrees = self.parse_params(
            degrees,
            None,
            'degrees',
            make_ranges=False,
        )
        self.translation = self.parse_params(
            translation,
            None,
            'translation',
            make_ranges=False,
        )
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.use_image_center = center == 'image'
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)
        # When True, the composite transform is inverted before resampling.
        self.invert_transform = False
        self.args_names = (
            'scales',
            'degrees',
            'translation',
            'center',
            'default_pad_value',
            'image_interpolation',
        )

    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scale transform, optionally centered."""
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            translation: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build an Euler 3D transform from angles in degrees and a translation."""
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample every selected image of ``subject`` in place.

        Returns:
            The same ``subject``, with each image's data replaced.
        """
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            # Non-intensity images always use nearest-neighbour interpolation.
            if image[TYPE] != INTENSITY:
                interpolation = 'nearest'
            else:
                interpolation = self.image_interpolation

            # Fresh copies per image, so the 2D adjustments below never leak
            # into the parameters used for other images of the subject.
            scaling_params = np.array(self.scales)
            rotation_params = np.array(self.degrees)
            translation_params = np.array(self.translation)

            if image.is_2d():
                # Do not scale along the third axis, and keep only the
                # rotation around the third axis.
                scaling_params[2] = 1
                rotation_params[:-1] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            # Each channel is resampled independently, then restacked.
            transformed_tensors = []
            for tensor in image[DATA]:
                transformed_tensor = self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                transformed_tensors.append(transformed_tensor)
            image[DATA] = torch.stack(transformed_tensors)
        return subject

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            translation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample one 3D tensor with the composed scale and rotation transforms.

        Args:
            tensor: 3D tensor (a single channel).
            affine: Affine matrix of the image the tensor belongs to.
            scaling_params: Scaling factor along each axis.
            rotation_params: Rotation angle around each axis, in degrees.
            translation_params: Translation along each axis, in mm.
            interpolation: Interpolation used by the resampler.
            center_lps: Optional center of scaling and rotation.

        Returns:
            The resampled tensor, on the same grid as the input.
        """
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        # The output grid is the input grid: the image is resampled onto itself.
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        sitk_major_version = get_major_sitk_version()
        if sitk_major_version == 1:
            # SimpleITK 1.x: build the composite from a generic Transform.
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        else:
            # SimpleITK >= 2 provides CompositeTransform. Using ``else``
            # instead of ``== 2`` guarantees ``transform`` is always bound.
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.invert_transform:
            transform = transform.GetInverse()

        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor
358
359
360
# flake8: noqa: E201, E203, E243
361
def get_borders_mean(image, filter_otsu=True):
    """Return the mean of the voxel values on the six faces of ``image``.

    Args:
        image: SimpleITK image whose border voxels are sampled.
        filter_otsu: If ``True``, only the border values below the Otsu
            threshold (computed on the border values themselves) are
            averaged. If no value lies below the threshold, the mean of all
            border values is returned instead.

    Returns:
        The mean border value, as a NumPy scalar.
    """
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    # The Otsu filter operates on images, so wrap the border values as a
    # 1 x 1 x N image.
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    # Test for emptiness, not truthiness: an all-zero selection (e.g. a zero
    # background) is valid and its mean is 0. Only an empty selection, whose
    # mean is undefined, falls back to the mean of all border values.
    if values.size > 0:
        return values.mean()
    return borders_flat.mean()
386
387
def _parse_scales_isotropic(scales, isotropic):
    """Check that ``scales`` is compatible with ``isotropic``.

    Isotropic scaling uses a single factor for all axes, so per-axis
    specifications (three or six values) are rejected.

    Args:
        scales: A single number or a sequence of scaling values.
        isotropic: Whether the scaling is isotropic.

    Raises:
        ValueError: If ``isotropic`` is ``True`` and ``scales`` has three or
            six values.
    """
    params = to_tuple(scales)
    # Measure the normalized tuple: ``scales`` may be a single number (the
    # RandomAffine default is 0.1), which has no ``len()``.
    if isotropic and len(params) in (3, 6):
        message = (
            'If "isotropic" is True, the value for "scales" must have'
            f' length 1 or 2, but "{scales}" was passed'
        )
        raise ValueError(message)
395
396
def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
397
    if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
398
        return value
399
    message = (
400
        'Value for default_pad_value must be "minimum", "otsu", "mean"'
401
        ' or a number'
402
    )
403
    raise ValueError(message)
404