Passed
Pull Request — master (#353)
by Fernando
01:09
created

Affine.__init__()   A

Complexity

Conditions 2

Size

Total Lines 41
Code Lines 35

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 35
nop 8
dl 0
loc 41
rs 9.0399
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
import copy
2
from numbers import Number
3
from typing import Tuple, Optional, List, Union
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ....data.subject import Subject
10
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
11
from ....torchio import (
12
    INTENSITY,
13
    DATA,
14
    AFFINE,
15
    TYPE,
16
    TypeRangeFloat,
17
    TypeSextetFloat,
18
    TypeTripletFloat,
19
)
20
from ... import SpatialTransform
21
from .. import Interpolation, get_sitk_interpolator
22
from .. import RandomTransform
23
24
25
# Type of RandomAffine's ``scales``, ``degrees`` and ``translation`` arguments,
# which (per the RandomAffine docstring) may be given as 1, 2, 3 or 6 values.
TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]
26
27
28
class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`. In that case ``scales`` must
            have one or two values; passing three or six values raises a
            :py:class:`ValueError`.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If ``center`` is not ``'image'`` or ``'origin'``, if
            ``default_pad_value`` is invalid, or if ``isotropic`` is ``True``
            and ``scales`` has three or six values.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=False,
        ...     default_pad_value='otsu',
        ...     image_interpolation='bspline',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "degrees=30 default_pad_value=minimum" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.isotropic = isotropic
        # Per-axis scales contradict a single shared scaling factor
        _parse_scales_isotropic(scales, isotropic)
        # NOTE(review): the second argument of parse_params appears to be the
        # value a single x is centred on (1 for scales, 0 otherwise), matching
        # the docstring above — confirm against RandomTransform.parse_params
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample scaling, rotation and translation parameters.

        Args:
            scales: Sextet of scaling ranges.
            degrees: Sextet of rotation ranges, in degrees.
            translation: Sextet of translation ranges, in mm.
            isotropic: If ``True``, all three scaling factors are set to the
                first sampled factor.

        Returns:
            Tuple ``(scaling_params, rotation_params, translation_params)``.
        """
        scaling_params = self.sample_uniform_sextet(scales)
        if isotropic:
            scaling_params.fill_(scaling_params[0])
        rotation_params = self.sample_uniform_sextet(degrees)
        translation_params = self.sample_uniform_sextet(translation)
        return scaling_params, rotation_params, translation_params

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample random parameters and apply them with :class:`Affine`."""
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        arguments = dict(
            scales=scaling_params.tolist(),
            degrees=rotation_params.tolist(),
            translation=translation_params.tolist(),
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.interpolation,
        )
        transform = Affine(**arguments)
        transformed = transform(subject)
        return transformed
162
163
164
class Affine(SpatialTransform):
    r"""Apply affine transformation.

    Non-intensity images are always resampled with nearest-neighbour
    interpolation; ``image_interpolation`` applies to intensity images only.
    For 2D images, the third scaling factor is forced to 1 and only the
    rotation around the last axis is kept.

    Args:
        scales: Tuple :math:`(s_1, s_2, s_3)` defining the
            scaling values along each dimension.
        degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
            rotation around each axis, in degrees.
        translation: Tuple :math:`(t_1, t_2, t_3)` defining the
            translation in mm along each axis.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If ``center`` is not ``'image'`` or ``'origin'`` or if
            ``default_pad_value`` is invalid.
    """
    def __init__(
            self,
            scales: TypeTripletFloat,
            degrees: TypeTripletFloat,
            translation: TypeTripletFloat,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        # make_ranges=False: these are fixed values, not sampling ranges
        self.scales = self.parse_params(
            scales,
            None,
            'scales',
            make_ranges=False,
            min_constraint=0,
        )
        self.degrees = self.parse_params(
            degrees,
            None,
            'degrees',
            make_ranges=False,
        )
        self.translation = self.parse_params(
            translation,
            None,
            'translation',
            make_ranges=False,
        )
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.use_image_center = center == 'image'
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)
        # When True, the inverse of the composed transform is applied instead
        self.invert_transform = False

    def get_arguments(self):
        """Return the keyword arguments that reproduce this transform."""
        arguments = dict(
            scales=self.scales,
            degrees=self.degrees,
            translation=self.translation,
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.interpolation,
        )
        return arguments

    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D scaling transform, optionally centred at ``center_lps``."""
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            translation: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build an Euler rotation (angles in degrees) plus translation."""
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample every selected image of ``subject`` in place."""
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            if image[TYPE] != INTENSITY:
                interpolation = Interpolation.NEAREST
            else:
                interpolation = self.interpolation

            # Fresh copies for every image: the 2D adjustments below must not
            # leak into the parameters used for subsequent (e.g. 3D) images
            scaling_params = np.array(self.scales)
            rotation_params = np.array(self.degrees)
            translation_params = np.array(self.translation)
            if image.is_2d():
                # Leave the third axis unscaled; rotate only around it
                scaling_params[2] = 1
                rotation_params[:-1] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            # Loop-invariant conversions, done once per image
            scales_list = scaling_params.tolist()
            degrees_list = rotation_params.tolist()
            translation_list = translation_params.tolist()
            transformed_tensors = [
                self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scales_list,
                    degrees_list,
                    translation_list,
                    interpolation,
                    center_lps=center,
                )
                for tensor in image[DATA]
            ]
            image[DATA] = torch.stack(transformed_tensors)
        return subject

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            translation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample a single 3D channel with the given affine parameters.

        Args:
            tensor: 3D tensor (one channel) to resample.
            affine: Affine matrix of the image the tensor belongs to.
            scaling_params: Three scaling factors.
            rotation_params: Three rotation angles in degrees.
            translation_params: Three translations in mm.
            interpolation: Interpolation used by the resampler.
            center_lps: Optional centre for scaling and rotation.

        Returns:
            The resampled channel as a ``float32`` tensor.
        """
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        sitk_major_version = get_major_sitk_version()
        if sitk_major_version < 2:
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        else:
            # SimpleITK 2.x and later expose an explicit CompositeTransform;
            # previously any version other than 1 or 2 left this unbound
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.invert_transform:
            transform = transform.GetInverse()

        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor
362
363
364
# flake8: noqa: E201, E203, E243
365
def get_borders_mean(image, filter_otsu=True):
    """Return the mean of the values on the six faces of a 3D image.

    Args:
        image: SimpleITK image; its array view is indexed along three axes.
        filter_otsu: If ``True``, average only the border values below the
            Otsu threshold computed over the border values. If no border
            value lies below the threshold, the mean of all border values is
            returned instead.

    Returns:
        The mean as a NumPy scalar.
    """
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    # Test for emptiness, not truthiness: ``values.any()`` is False when all
    # selected values are zero (e.g. a zero background), which previously
    # caused a wrong fallback to the mean of every border value
    if values.size > 0:
        return values.mean()
    return borders_flat.mean()
390
391
def _parse_scales_isotropic(scales, isotropic):
392
    params = to_tuple(scales)
393
    if isotropic and len(scales) in (3, 6):
394
        message = (
395
            'If "isotropic" is True, the value for "scales" must have'
396
            f' length 1 or 2, but "{scales}" was passed'
397
        )
398
        raise ValueError(message)
399
400
def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
401
    if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
402
        return value
403
    message = (
404
        'Value for default_pad_value must be "minimum", "otsu", "mean"'
405
        ' or a number'
406
    )
407
    raise ValueError(message)
408