Passed
Pull Request — master (#353)
by Fernando
01:11
created

RandomAffine.__init__()   A

Complexity

Conditions 2

Size

Total Lines 27
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 24
nop 10
dl 0
loc 27
rs 9.304
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

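Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent as you need more or different data.

There are several approaches to avoid long parameter lists:

- Replace a parameter with a method call, when the method can obtain the value itself.
- Introduce a parameter object that groups values which always travel together.
- Preserve the whole object: pass the object itself instead of several values extracted from it.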

1
import copy
2
from numbers import Number
3
from typing import Tuple, Optional, List, Union
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ....data.subject import Subject
10
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
11
from ....torchio import (
12
    INTENSITY,
13
    DATA,
14
    AFFINE,
15
    TYPE,
16
    TypeRangeFloat,
17
    TypeSextetFloat,
18
    TypeTripletFloat,
19
)
20
from ... import SpatialTransform
21
from .. import Interpolation, get_sitk_interpolator
22
from .. import RandomTransform
23
24
25
# Type of the ``scales``, ``degrees`` and ``translation`` arguments of
# RandomAffine: a TypeRangeFloat (presumably one value or an (a, b) pair),
# three values (TypeTripletFloat) or six values (TypeSextetFloat). The
# RandomAffine docstring describes how each length is interpreted.
TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]
26
27
28
class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`. In that case ``scales`` must
            contain one or two values; per-axis ranges raise a
            ``ValueError``.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=False,
        ...     default_pad_value='otsu',
        ...     image_interpolation='bspline',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "degrees=30 default_pad_value=minimum" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.isotropic = isotropic
        # Reject per-axis scale ranges before they are expanded, because
        # isotropic sampling keeps only one scaling factor.
        _parse_scales_isotropic(scales, isotropic)
        # Scales are ranges around 1 and must be non-negative; degrees and
        # translation are ranges around 0.
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample concrete scaling, rotation and translation parameters.

        Args:
            scales: Ranges ``(a_1, b_1, a_2, b_2, a_3, b_3)`` for the scaling
                factors.
            degrees: Ranges for the rotation angles, in degrees.
            translation: Ranges for the translations, in mm.
            isotropic: If ``True``, the first sampled scaling factor is
                copied to every axis.

        Returns:
            Tensors ``(scaling, rotation, translation)``, each as returned by
            ``sample_uniform_sextet`` (presumably one value per axis).
        """
        scaling_params = self.sample_uniform_sextet(scales)
        if isotropic:
            scaling_params.fill_(scaling_params[0])
        rotation_params = self.sample_uniform_sextet(degrees)
        translation_params = self.sample_uniform_sextet(translation)
        return scaling_params, rotation_params, translation_params

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample parameters and apply them through a deterministic Affine.

        The sampled arguments are recorded with ``add_transform`` so the
        exact transform applied to ``subject`` can be traced later.
        """
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        arguments = dict(
            scales=scaling_params.tolist(),
            degrees=rotation_params.tolist(),
            translation=translation_params.tolist(),
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.interpolation,
        )
        # NOTE(review): ``keys`` given to this transform is not forwarded to
        # Affine, so the deterministic transform may act on images outside
        # ``keys``. Confirm how Transform stores ``keys`` before passing it on.
        transform = Affine(**arguments)
        transformed = transform(subject)
        transformed.add_transform(transform, arguments)
        return transformed
163
164
165
class Affine(SpatialTransform):
    r"""Apply affine transformation.

    Args:
        scales: Tuple :math:`(s_1, s_2, s_3)` defining the
            scaling values along each dimension.
        degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
            rotation around each axis, in degrees.
        translation: Tuple :math:`(t_1, t_2, t_3)` defining the
            translation in mm along each axis.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    For 2D images, only in-plane transformations are applied. Scaling and
    translation along the third axis, and rotations about the first two
    axes, are ignored.
    """
    def __init__(
            self,
            scales: TypeTripletFloat,
            degrees: TypeTripletFloat,
            translation: TypeTripletFloat,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        # make_ranges=False: these are concrete values, not sampling ranges.
        self.scales = self.parse_params(
            scales,
            None,
            'scales',
            make_ranges=False,
            min_constraint=0,
        )
        self.degrees = self.parse_params(
            degrees,
            None,
            'degrees',
            make_ranges=False,
        )
        self.translation = self.parse_params(
            translation,
            0,
            'translation',
            make_ranges=False,
        )
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.use_image_center = center == 'image'
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)
        # Toggled by inverse(); when True the composed transform is inverted.
        self.invert_transform = False

    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scaling transform.

        Args:
            scaling_params: Scaling factor per axis.
            center_lps: Optional center of scaling in LPS world coordinates.
        """
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            translation: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build a SimpleITK Euler transform holding rotation and translation.

        Args:
            degrees: Rotation angle about each axis, in degrees (converted to
                radians for SimpleITK).
            translation: Translation along each axis, in mm.
            center_lps: Optional center of rotation in LPS world coordinates.
        """
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample each selected image of ``subject`` in place; return it."""
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            if image[TYPE] != INTENSITY:
                # Nearest neighbour keeps discrete (non-intensity) values intact
                interpolation = Interpolation.NEAREST
            else:
                interpolation = self.interpolation

            # Fresh copies for every image so the 2D adjustments below can
            # never leak into the parameters used for another image.
            scaling_params = np.array(self.scales)
            rotation_params = np.array(self.degrees)
            translation_params = np.array(self.translation)
            if image.is_2d():
                # Keep the transform in-plane: no scaling or translation along
                # the third axis and rotation only about it. Otherwise the
                # single slice could be moved out of the sampled volume.
                scaling_params[2] = 1
                rotation_params[:-1] = 0
                translation_params[2] = 0

            center = image.get_center(lps=True) if self.use_image_center else None

            # Each channel is resampled independently, then restacked.
            transformed_tensors = [
                self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                for tensor in image[DATA]
            ]
            image[DATA] = torch.stack(transformed_tensors)
        return subject

    @staticmethod
    def _compose_transforms(
            scaling_transform: sitk.ScaleTransform,
            rotation_transform: sitk.Euler3DTransform,
            ) -> sitk.Transform:
        """Combine the scaling and rotation transforms into one composite."""
        if get_major_sitk_version() == 1:
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
            return transform
        # SimpleITK 2 introduced CompositeTransform. Using it for every version
        # other than 1 means the transform is always defined, including on
        # releases newer than 2.
        return sitk.CompositeTransform([scaling_transform, rotation_transform])

    def _get_default_value(self, tensor: torch.Tensor, image) -> float:
        """Return the fill value for output voxels that map outside the input.

        Args:
            tensor: The channel being resampled (used for ``'minimum'``).
            image: The same channel as a SimpleITK image (used for border
                statistics).
        """
        if self.default_pad_value == 'minimum':
            return tensor.min().item()
        if self.default_pad_value == 'mean':
            return get_borders_mean(image, filter_otsu=False)
        if self.default_pad_value == 'otsu':
            return get_borders_mean(image, filter_otsu=True)
        # Otherwise it is the numeric value validated in __init__
        return self.default_pad_value

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            translation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample one 3D channel with the given affine parameters.

        The output grid is the input grid (the image is its own reference),
        and the output is a float32 tensor of the same spatial shape.
        """
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )
        transform = self._compose_transforms(
            scaling_transform,
            rotation_transform,
        )
        if self.invert_transform:
            transform = transform.GetInverse()

        default_value = self._get_default_value(tensor, image)

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        return torch.from_numpy(np_array)

    def inverse(self):
        """Return a deep copy of this transform with inversion toggled."""
        new = copy.deepcopy(self)
        new.invert_transform = not self.invert_transform
        return new
351
352
353
# flake8: noqa: E201, E203, E243
def get_borders_mean(image, filter_otsu=True):
    """Return the mean intensity of the voxels on the six faces of ``image``.

    Args:
        image: SimpleITK image whose array view is indexed as 3D.
        filter_otsu: If ``True``, average only the border values that lie
            below an Otsu threshold computed on the border values themselves.
            If no value lies below the threshold, fall back to the mean of
            all border values.

    Returns:
        The mean as a NumPy scalar.
    """
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    # Otsu works on an image, so wrap the flat border values as a 1x1xN one
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    # Test for emptiness, not truthiness: values.any() is False for an
    # all-zero dark background, which would wrongly fall back to the mean
    # of every border value instead of returning 0.
    if values.size > 0:
        return values.mean()
    return borders_flat.mean()
379
380
def _parse_scales_isotropic(scales, isotropic):
    """Reject per-axis scale ranges when isotropic scaling is requested.

    Args:
        scales: Raw ``scales`` argument given to RandomAffine: a number or a
            sequence of 1, 2, 3 or 6 numbers.
        isotropic: Whether a single scaling factor is shared by all axes.

    Raises:
        ValueError: If ``isotropic`` is ``True`` and ``scales`` has 3 or 6
            values, since per-axis ranges cannot be honoured.
    """
    # Measure the normalised tuple: len() of a bare number (e.g. the default
    # scales=0.1) would raise TypeError instead of validating.
    params = to_tuple(scales)
    if isotropic and len(params) in (3, 6):
        message = (
            'If "isotropic" is True, the value for "scales" must have'
            f' length 1 or 2, but "{scales}" was passed'
        )
        raise ValueError(message)
388
389
def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
    """Validate a ``default_pad_value`` argument.

    Args:
        value: A number used verbatim as the fill value, or the name of a
            fill strategy: ``'minimum'``, ``'otsu'`` or ``'mean'``.

    Returns:
        ``value`` itself, unchanged, when it is valid.

    Raises:
        ValueError: If ``value`` is neither a number nor a known strategy.
    """
    strategies = ('minimum', 'otsu', 'mean')
    # Numbers are checked first so the membership test only runs on
    # non-numeric input.
    if not isinstance(value, Number) and value not in strategies:
        message = (
            'Value for default_pad_value must be "minimum", "otsu", "mean"'
            ' or a number'
        )
        raise ValueError(message)
    return value
397