Passed
Pull Request — master (#353)
Created by Fernando · 01:11

Affine.__init__() (rated B)

Complexity
    Conditions: 2

Size
    Total Lines: 48
    Code Lines: 42

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric   Value
------   -----
cc       2
eloc     42
nop      8
dl       0
loc      48
rs       8.872
c        0
b        0
f        0

How to fix: Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you need more, or different, data.

There are several approaches to avoiding long parameter lists; two common ones are grouping related arguments into a single parameter object and splitting the method into smaller methods that each need fewer arguments. A sketch of the parameter-object approach is shown below.
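The sketch is illustrative only: it is not part of the pull request, and the AffineParams name is hypothetical. It groups six of the Affine.__init__() arguments into a dataclass, so the constructor would only need the parameter object plus the existing keys argument:

    from dataclasses import dataclass
    from typing import Sequence, Union

    TypeTripletFloat = Sequence[float]  # simplified stand-in for the torchio type alias


    @dataclass
    class AffineParams:
        # Hypothetical parameter object bundling the geometric and resampling options
        scales: TypeTripletFloat
        degrees: TypeTripletFloat
        translation: TypeTripletFloat
        center: str = 'image'
        default_pad_value: Union[str, float] = 'minimum'
        image_interpolation: str = 'linear'

Adding a new option then becomes a one-line change to the dataclass instead of a change to every call site. The full source of the file under review follows.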

import copy
from numbers import Number
from typing import Tuple, Optional, Sequence, Union

import torch
import numpy as np
import SimpleITK as sitk

from ....data.subject import Subject
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
from ....torchio import (
    INTENSITY,
    DATA,
    AFFINE,
    TYPE,
    TypeRangeFloat,
    TypeSextetFloat,
    TypeTripletFloat,
)
from ... import SpatialTransform
from .. import RandomTransform


TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]


class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=True,
        ...     image_interpolation='nearest',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "scales=(0.9, 1.2) degrees=10 isotropic=True image_interpolation=nearest" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.isotropic = isotropic
        _parse_scales_isotropic(scales, isotropic)
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        scaling_params = self.sample_uniform_sextet(scales)
        if isotropic:
            scaling_params.fill_(scaling_params[0])
        rotation_params = self.sample_uniform_sextet(degrees)
        translation_params = self.sample_uniform_sextet(translation)
        return scaling_params, rotation_params, translation_params

    def apply_transform(self, subject: Subject) -> Subject:
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        arguments = dict(
            scales=scaling_params.tolist(),
            degrees=rotation_params.tolist(),
            translation=translation_params.tolist(),
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.image_interpolation,
        )
        transform = Affine(**arguments)
        transformed = transform(subject)
        return transformed


class Affine(SpatialTransform):
    r"""Apply affine transformation.

    Args:
        scales: Tuple :math:`(s_1, s_2, s_3)` defining the
            scaling values along each dimension.
        degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
            rotation around each axis.
        translation: Tuple :math:`(t_1, t_2, t_3)` defining the
            translation in mm along each axis.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        keys: See :py:class:`~torchio.transforms.Transform`.
    """
    def __init__(
            self,
            scales: TypeTripletFloat,
            degrees: TypeTripletFloat,
            translation: TypeTripletFloat,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(keys=keys)
        self.scales = self.parse_params(
            scales,
            None,
            'scales',
            make_ranges=False,
            min_constraint=0,
        )
        self.degrees = self.parse_params(
            degrees,
            None,
            'degrees',
            make_ranges=False,
        )
        self.translation = self.parse_params(
            translation,
            None,
            'translation',
            make_ranges=False,
        )
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.use_image_center = center == 'image'
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)
        self.invert_transform = False
        self.args_names = (
            'scales',
            'degrees',
            'translation',
            'center',
            'default_pad_value',
            'image_interpolation',
        )

    @staticmethod
    def get_scaling_transform(
            scaling_params: Sequence[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: Sequence[float],
            translation: Sequence[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        scaling_params = np.array(self.scales).copy()
        rotation_params = np.array(self.degrees).copy()
        translation_params = np.array(self.translation).copy()
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            if image[TYPE] != INTENSITY:
                interpolation = 'nearest'
            else:
                interpolation = self.image_interpolation

            if image.is_2d():
                scaling_params[2] = 1
                rotation_params[:-1] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            transformed_tensors = []
            # Review tool note: DATA was flagged as possibly undefined here;
            # it is imported from ....torchio at the top of the file.
            for tensor in image[DATA]:
                transformed_tensor = self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                transformed_tensors.append(transformed_tensor)
            image[DATA] = torch.stack(transformed_tensors)
        return subject

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: Sequence[float],
            rotation_params: Sequence[float],
            translation_params: Sequence[float],
            interpolation: str,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        sitk_major_version = get_major_sitk_version()
        if sitk_major_version == 1:
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        elif sitk_major_version == 2:
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        # Review tool note: transform is not defined for all execution paths,
        # i.e. if get_major_sitk_version() returns something other than 1 or 2.
        if self.invert_transform:
            transform = transform.GetInverse()

        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(self.get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor


# flake8: noqa: E201, E203, E243
def get_borders_mean(image, filter_otsu=True):
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    if values.any():
        default_value = values.mean()
    else:
        default_value = borders_flat.mean()
    return default_value


def _parse_scales_isotropic(scales, isotropic):
    params = to_tuple(scales)  # normalize a single number to a tuple so len() works
    if isotropic and len(params) in (3, 6):
        message = (
            'If "isotropic" is True, the value for "scales" must have'
            f' length 1 or 2, but "{scales}" was passed'
        )
        raise ValueError(message)


def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
    if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
        return value
    message = (
        'Value for default_pad_value must be "minimum", "otsu", "mean"'
        ' or a number'
    )
    raise ValueError(message)