Passed
Pull Request — master (#353)
by Fernando
01:16
created

RandomAffine.__init__()   A

Complexity

Conditions 2

Size

Total Lines 27
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 24
nop 10
dl 0
loc 27
rs 9.304
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
import copy
2
from numbers import Number
3
from typing import Tuple, Optional, Sequence, Union
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ....data.subject import Subject
10
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
11
from ....torchio import (
12
    INTENSITY,
13
    DATA,
14
    AFFINE,
15
    TYPE,
16
    TypeRangeFloat,
17
    TypeSextetFloat,
18
    TypeTripletFloat,
19
)
20
from ... import SpatialTransform
21
from .. import Interpolation, get_sitk_interpolator
22
from .. import RandomTransform
23
24
25
# Annotation used by RandomAffine.__init__ for its "scales", "degrees" and
# "translation" arguments: any of the range, triplet or sextet float types.
TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]
26
27
28
class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=True,
        ...     image_interpolation='nearest',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "scales=(0.9, 1.2) degrees=10 isotropic=True image_interpolation=nearest" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        """Validate the arguments and store the parsed sampling ranges.

        Raises:
            ValueError: If ``center`` is neither ``'image'`` nor ``'origin'``.
                The parsing helpers called here may raise as well.
        """
        super().__init__(p=p, keys=keys)
        self.isotropic = isotropic
        # Check that the "scales" length is compatible with "isotropic"
        # before the value is expanded by parse_params
        _parse_scales_isotropic(scales, isotropic)
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        valid_centers = ('image', 'origin')
        if center not in valid_centers:
            raise ValueError(
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
        self.center = center
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample scaling, rotation and translation parameters.

        Each sextet is passed to ``self.sample_uniform_sextet``; values are
        drawn in the order scales, degrees, translation.

        Returns:
            Tuple ``(scaling, rotation, translation)`` with the sampled values.
        """
        sampled_scales = self.sample_uniform_sextet(scales)
        if isotropic:
            # Reuse the first sampled factor for every dimension
            sampled_scales.fill_(sampled_scales[0])
        sampled_degrees = self.sample_uniform_sextet(degrees)
        sampled_translation = self.sample_uniform_sextet(translation)
        return sampled_scales, sampled_degrees, sampled_translation

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample affine parameters and apply them to the subject.

        The sampled values are handed to an :class:`Affine` instance, which
        performs the actual resampling.
        """
        subject.check_consistent_spatial_shape()
        sampled = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        sampled_scales, sampled_degrees, sampled_translation = sampled
        affine_transform = Affine(
            scales=sampled_scales.tolist(),
            degrees=sampled_degrees.tolist(),
            translation=sampled_translation.tolist(),
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.image_interpolation,
        )
        return affine_transform(subject)
160
161
162
class Affine(SpatialTransform):
    r"""Apply affine transformation.

    Args:
        scales: Tuple :math:`(s_1, s_2, s_3)` defining the
            scaling values along each dimension.
        degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
            rotation in degrees around each axis.
        translation: Tuple :math:`(t_1, t_2, t_3)` defining the
            translation in mm along each axis.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        keys: See :py:class:`~torchio.transforms.Transform`.
    """
    def __init__(
            self,
            scales: TypeTripletFloat,
            degrees: TypeTripletFloat,
            translation: TypeTripletFloat,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            keys: Optional[Sequence[str]] = None,
            ):
        """Validate and store the affine parameters.

        Raises:
            ValueError: If ``center`` is neither ``'image'`` nor ``'origin'``.
                The parsing helpers called here may raise as well.
        """
        super().__init__(keys=keys)
        self.scales = self.parse_params(
            scales,
            None,
            'scales',
            make_ranges=False,
            min_constraint=0,
        )
        self.degrees = self.parse_params(
            degrees,
            None,
            'degrees',
            make_ranges=False,
        )
        self.translation = self.parse_params(
            translation,
            None,
            'translation',
            make_ranges=False,
        )
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.use_image_center = center == 'image'
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)
        self.invert_transform = False
        self.args_names = (
            'scales',
            'degrees',
            'translation',
            'center',
            'default_pad_value',
            'image_interpolation',
        )

    @staticmethod
    def get_scaling_transform(
            scaling_params: Sequence[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scale transform from the scaling factors.

        Args:
            scaling_params: Scaling factor for each dimension.
            center_lps: Optional center of the transform. If ``None``, the
                center is not set on the transform.
        """
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: Sequence[float],
            translation: Sequence[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build a SimpleITK Euler transform with rotation and translation.

        Args:
            degrees: Rotation angles in degrees; converted to radians here.
            translation: Translation along each axis.
            center_lps: Optional center of the transform. If ``None``, the
                center is not set on the transform.
        """
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        """Apply the affine transform to every selected image of the subject.

        Images whose type is not ``INTENSITY`` are always resampled with
        ``'nearest'`` interpolation. Each entry of ``image[DATA]`` is
        transformed separately and the results are stacked back together.
        """
        # Work on copies so that the stored parameters are never modified
        scaling_params = np.array(self.scales).copy()
        rotation_params = np.array(self.degrees).copy()
        translation_params = np.array(self.translation).copy()
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            if image[TYPE] != INTENSITY:
                interpolation = 'nearest'
            else:
                interpolation = self.image_interpolation

            if image.is_2d():
                # Leave the third dimension unscaled and keep only the
                # rotation around the last axis
                scaling_params[2] = 1
                rotation_params[:-1] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            transformed_tensors = []
            for tensor in image[DATA]:
                transformed_tensor = self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                transformed_tensors.append(transformed_tensor)
            image[DATA] = torch.stack(transformed_tensors)
        return subject

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: Sequence[float],
            rotation_params: Sequence[float],
            translation_params: Sequence[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample one 3D tensor with the composed scaling and rotation.

        Args:
            tensor: Tensor with exactly three dimensions.
            affine: Affine matrix passed to ``nib_to_sitk`` with the tensor.
            scaling_params: Scaling factor for each dimension.
            rotation_params: Rotation angles in degrees.
            translation_params: Translation along each axis.
            interpolation: Interpolation passed to ``get_sitk_interpolator``.
            center_lps: Optional center for scaling and rotation.

        Returns:
            The resampled data as a tensor (output pixel type is float32).

        Raises:
            RuntimeError: If the SimpleITK major version is neither 1 nor 2.
        """
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        # The image is resampled onto its own grid
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        # The API for composing transforms differs between major versions
        sitk_major_version = get_major_sitk_version()
        if sitk_major_version == 1:
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        elif sitk_major_version == 2:
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)
        else:
            # Fail explicitly instead of hitting an unbound "transform" below
            message = (
                'Only SimpleITK major versions 1 and 2 are supported,'
                f' but version {sitk_major_version} was found'
            )
            raise RuntimeError(message)

        if self.invert_transform:
            transform = transform.GetInverse()

        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            # A number was passed; use it directly as the fill value
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor
357
358
359
# flake8: noqa: E201, E203, E243
360
def get_borders_mean(image, filter_otsu=True):
    """Compute the mean of the values on the six border faces of an image.

    Args:
        image: SimpleITK image whose array view is indexed with three axes.
        filter_otsu: If ``True``, only border values strictly below an Otsu
            threshold computed on the border values are averaged. If no
            border value lies below the threshold, the mean of all border
            values is returned instead.

    Returns:
        Mean of the selected border values.
    """
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    # The Otsu filter operates on an image, so wrap the flat border values
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    # Test for emptiness, not truthiness: a selection made only of zeros
    # (e.g. a zero background) is valid and must be averaged as well
    if values.size > 0:
        default_value = values.mean()
    else:
        default_value = borders_flat.mean()
    return default_value
385
386
def _parse_scales_isotropic(scales, isotropic):
    """Check that "scales" is compatible with isotropic scaling.

    Args:
        scales: Value passed as the "scales" argument, either a single
            number or a sequence of numbers.
        isotropic: Whether the same scaling factor is used for all dimensions.

    Raises:
        ValueError: If ``isotropic`` is true and "scales" has 3 or 6 values.
    """
    # A bare number has no len(), so measure the tuple form instead
    # (assumes to_tuple wraps a single number in a tuple; confirm in utils)
    params = to_tuple(scales)
    if isotropic and len(params) in (3, 6):
        message = (
            'If "isotropic" is True, the value for "scales" must have'
            f' length 1 or 2, but "{scales}" was passed'
        )
        raise ValueError(message)
394
395
def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
396
    if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
397
        return value
398
    message = (
399
        'Value for default_pad_value must be "minimum", "otsu", "mean"'
400
        ' or a number'
401
    )
402
    raise ValueError(message)
403