Passed
Pull Request — master (#353)
by Fernando
01:07
created

RandomAffine.__init__()   A

Complexity

Conditions 2

Size

Total Lines 27
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 24
nop 10
dl 0
loc 27
rs 9.304
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are hard to understand, and their parameter lists tend to become inconsistent as more or different data is needed.

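There are several approaches to avoid long parameter lists, such as grouping related parameters into a single parameter object, or passing an existing object instead of several of its attributes.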

1
import copy
2
from numbers import Number
3
from typing import Tuple, Optional, List, Union
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ....data.subject import Subject
10
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
11
from ....torchio import (
12
    INTENSITY,
13
    DATA,
14
    AFFINE,
15
    TYPE,
16
    TypeRangeFloat,
17
    TypeSextetFloat,
18
    TypeTripletFloat,
19
)
20
from ... import SpatialTransform
21
from .. import Interpolation, get_sitk_interpolator
22
from .. import RandomTransform
23
24
25
# Union of the range, triplet and sextet float types; used to annotate the
# ``scales``, ``degrees`` and ``translation`` arguments of RandomAffine,
# which accept 1, 2, 3 or 6 values
TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]
26
27
28
class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`. In that case, ``scales`` must
            contain one or two values.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If ``center`` is not ``'image'`` or ``'origin'``, if
            ``default_pad_value`` is neither a number nor one of
            ``'minimum'``, ``'mean'`` or ``'otsu'``, or if ``isotropic`` is
            ``True`` and ``scales`` contains three or six values.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=True,
        ...     image_interpolation='nearest',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "scales=(0.9, 1.2) degrees=10 isotropic=True image_interpolation=nearest" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.isotropic = isotropic
        # Reject per-axis scaling ranges when a single factor is requested
        _parse_scales_isotropic(scales, isotropic)
        # Scales are centered on 1 (identity); degrees and translation on 0
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample scaling, rotation and translation parameters.

        Returns:
            Tuple ``(scaling, rotation, translation)`` of tensors sampled
            uniformly from the given ranges. If ``isotropic`` is ``True``,
            every scaling value is set to the first sampled one.
        """
        scaling_params = self.sample_uniform_sextet(scales)
        if isotropic:
            scaling_params.fill_(scaling_params[0])
        rotation_params = self.sample_uniform_sextet(degrees)
        translation_params = self.sample_uniform_sextet(translation)
        return scaling_params, rotation_params, translation_params

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample parameters and apply them to ``subject`` using ``Affine``."""
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        arguments = dict(
            scales=scaling_params.tolist(),
            degrees=rotation_params.tolist(),
            translation=translation_params.tolist(),
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.interpolation,
        )
        transform = Affine(**arguments)
        transformed = transform(subject)
        return transformed
161
162
163
class Affine(SpatialTransform):
    r"""Apply affine transformation.

    Args:
        scales: Tuple :math:`(s_1, s_2, s_3)` defining the
            scaling values along each dimension.
        degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
            rotation, in degrees, around each axis.
        translation: Tuple :math:`(t_1, t_2, t_3)` defining the
            translation in mm along each axis.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`. Images whose type is
            not ``INTENSITY`` are always resampled with ``'nearest'``.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If ``center`` is not ``'image'`` or ``'origin'``, or if
            ``default_pad_value`` is neither a number nor one of
            ``'minimum'``, ``'mean'`` or ``'otsu'``.
    """
    def __init__(
            self,
            scales: TypeTripletFloat,
            degrees: TypeTripletFloat,
            translation: TypeTripletFloat,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        # Fixed values (not ranges), hence make_ranges=False
        self.scales = self.parse_params(
            scales,
            None,
            'scales',
            make_ranges=False,
            min_constraint=0,
        )
        self.degrees = self.parse_params(
            degrees,
            None,
            'degrees',
            make_ranges=False,
        )
        self.translation = self.parse_params(
            translation,
            None,
            'translation',
            make_ranges=False,
        )
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.use_image_center = center == 'image'
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)
        # Read in apply_affine_transform; when True, the composed transform
        # is inverted before resampling
        self.invert_transform = False

    def get_arguments(self):
        """Return the keyword arguments needed to rebuild this transform."""
        arguments = dict(
            scales=self.scales,
            degrees=self.degrees,
            translation=self.translation,
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.interpolation,
        )
        return arguments

    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scaling transform, optionally centered."""
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            translation: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build an Euler 3D transform from angles in degrees and a shift."""
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample every selected image of ``subject`` in place."""
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            if image[TYPE] != INTENSITY:
                interpolation = 'nearest'
            else:
                interpolation = self.interpolation

            # Fresh copies per image, so the 2D adjustments below never
            # carry over from one image to the next
            scaling_params = np.array(self.scales)
            rotation_params = np.array(self.degrees)
            translation_params = np.array(self.translation)

            if image.is_2d():
                # Assumes is_2d() means a single slice along the last axis
                # (TODO confirm): no scaling along that axis, and rotation
                # only around it.
                # NOTE(review): translation_params[2] is not reset here, so
                # a shift along the last axis may move the slice out of the
                # sampling grid - confirm whether that is intended.
                scaling_params[2] = 1
                rotation_params[:-1] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            # Each element of image[DATA] is resampled independently
            transformed_tensors = [
                self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                for tensor in image[DATA]
            ]
            image[DATA] = torch.stack(transformed_tensors)
        return subject

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            translation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample a single 3D tensor with the composed affine transform.

        The output is resampled on the input's own grid with float32 pixels.
        """
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        if get_major_sitk_version() == 1:
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        else:
            # Use CompositeTransform for version 2 and any later major
            # version, instead of leaving `transform` unbound
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.invert_transform:
            transform = transform.GetInverse()

        default_value = self._get_default_pad_value(tensor, image)

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor

    def _get_default_pad_value(self, tensor: torch.Tensor, image):
        """Return the fill value for points sampled outside ``image``."""
        if self.default_pad_value == 'minimum':
            return tensor.min().item()
        if self.default_pad_value == 'mean':
            return get_borders_mean(image, filter_otsu=False)
        if self.default_pad_value == 'otsu':
            return get_borders_mean(image, filter_otsu=True)
        return self.default_pad_value
361
362
363
# flake8: noqa: E201, E203, E243
364
def get_borders_mean(image, filter_otsu=True):
    """Return the mean value of the voxels on the six faces of ``image``.

    Args:
        image: 3D SimpleITK image whose border voxels are averaged.
        filter_otsu: If ``True``, an Otsu threshold is computed from the
            border voxels and only values strictly below it are averaged.
            If no border value lies below the threshold, the mean of all
            border values is returned instead.

    Returns:
        The mean of the selected border values, as a NumPy scalar.
    """
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    # The Otsu filter needs an image, so wrap the flat values in a 1x1xN one
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    # Test for emptiness, not truthiness: values.any() is False when every
    # selected value is 0 (e.g. a zero background), which used to discard
    # a valid selection and fall back to the mean of all borders
    if values.size > 0:
        return values.mean()
    return borders_flat.mean()
389
390
def _parse_scales_isotropic(scales, isotropic):
    """Check that ``scales`` is compatible with isotropic scaling.

    ``scales`` is normalized with ``to_tuple`` before its length is checked,
    so a single number (such as the default ``0.1``) is accepted. Calling
    ``len`` on the raw value raised ``TypeError`` for numbers.

    Raises:
        ValueError: If ``isotropic`` is ``True`` and ``scales`` has three or
            six values, i.e. it describes per-axis scaling.
    """
    params = to_tuple(scales)
    if isotropic and len(params) in (3, 6):
        message = (
            'If "isotropic" is True, the value for "scales" must have'
            f' length 1 or 2, but "{scales}" was passed'
        )
        raise ValueError(message)
398
399
def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
400
    if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
401
        return value
402
    message = (
403
        'Value for default_pad_value must be "minimum", "otsu", "mean"'
404
        ' or a number'
405
    )
406
    raise ValueError(message)
407