Passed
Push — master ( e28dfa...5bce14 )
by Fernando
04:12
created

Affine.__init__()   B

Complexity

Conditions 2

Size

Total Lines 48
Code Lines 42

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 42
nop 8
dl 0
loc 48
rs 8.872
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also tend to become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
from numbers import Number
2
from typing import Tuple, Optional, Sequence, Union
3
4
import torch
5
import numpy as np
6
import SimpleITK as sitk
7
8
from ....data.io import nib_to_sitk
9
from ....data.subject import Subject
10
from ....constants import INTENSITY, TYPE
11
from ....utils import get_major_sitk_version, to_tuple
12
from ....typing import TypeRangeFloat, TypeSextetFloat, TypeTripletFloat
13
from ... import SpatialTransform
14
from .. import RandomTransform
15
16
17
# Accepted type for the ``scales``, ``degrees`` and ``translation`` arguments
# of ``RandomAffine``: a range, a triplet or a sextet of floats (the
# ``RandomAffine`` docstring explains how each form is interpreted).
TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]
18
19
20
class RandomAffine(RandomTransform, SpatialTransform):
    r"""Apply a random affine transformation and resample the image.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look half as large while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=True,
        ...     image_interpolation='nearest',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "scales=(0.9, 1.2) degrees=10 isotropic=True image_interpolation=nearest" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            **kwargs
            ):
        super().__init__(**kwargs)
        self.isotropic = isotropic
        # Reject per-axis scale ranges up front when isotropic scaling is on
        _parse_scales_isotropic(scales, isotropic)
        # Scales are centred on 1 (identity), rotations/translations on 0
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        allowed_centers = ('image', 'origin')
        if center not in allowed_centers:
            raise ValueError(
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
        self.center = center
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample scaling, rotation and translation parameters.

        Each parameter is drawn with ``self.sample_uniform_sextet``, in the
        order scales, degrees, translation. When ``isotropic`` is truthy,
        every sampled scaling value is overwritten with the first one.

        Returns:
            Tuple ``(scaling, rotation, translation)`` of sampled tensors.
        """
        scaling = self.sample_uniform_sextet(scales)
        if isotropic:
            first_scale = scaling[0]
            scaling.fill_(first_scale)
        rotation = self.sample_uniform_sextet(degrees)
        shift = self.sample_uniform_sextet(translation)
        return scaling, rotation, shift

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample parameters and delegate the resampling to :class:`Affine`."""
        subject.check_consistent_spatial_shape()
        scaling, rotation, shift = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        arguments = {
            'scales': scaling.tolist(),
            'degrees': rotation.tolist(),
            'translation': shift.tolist(),
            'center': self.center,
            'default_pad_value': self.default_pad_value,
            'image_interpolation': self.image_interpolation,
        }
        affine = Affine(**self.add_include_exclude(arguments))
        return affine(subject)
151
152
153
class Affine(SpatialTransform):
    r"""Apply affine transformation.

    Args:
        scales: Tuple :math:`(s_1, s_2, s_3)` defining the
            scaling values along each dimension.
        degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
            rotation around each axis, in degrees.
        translation: Tuple :math:`(t_1, t_2, t_3)` defining the
            translation in mm along each axis.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the minimum of the
            channel being resampled.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
            This only applies to intensity images; other images are
            resampled with nearest-neighbor interpolation and padded with 0.
        image_interpolation: See :ref:`Interpolation`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
    """
    def __init__(
            self,
            scales: TypeTripletFloat,
            degrees: TypeTripletFloat,
            translation: TypeTripletFloat,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            **kwargs
            ):
        super().__init__(**kwargs)
        # make_ranges=False: these are fixed values, not ranges to sample from
        self.scales = self.parse_params(
            scales,
            None,
            'scales',
            make_ranges=False,
            min_constraint=0,
        )
        self.degrees = self.parse_params(
            degrees,
            None,
            'degrees',
            make_ranges=False,
        )
        self.translation = self.parse_params(
            translation,
            None,
            'translation',
            make_ranges=False,
        )
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.use_image_center = center == 'image'
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)
        # When True, get_affine_transform returns the inverse transform
        # (presumably toggled by the inversion machinery — verify in callers)
        self.invert_transform = False
        self.args_names = (
            'scales',
            'degrees',
            'translation',
            'center',
            'default_pad_value',
            'image_interpolation',
        )

    @staticmethod
    def _get_scaling_transform(
            scaling_params: Sequence[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scale transform, optionally centred."""
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def _get_rotation_transform(
            degrees: Sequence[float],
            translation: Sequence[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build an Euler 3D transform from angles in degrees and a shift."""
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def get_affine_transform(self, image):
        """Compose the scaling and rotation/translation SimpleITK transforms.

        For 2D images, scaling along the third axis is forced to 1 and only
        the rotation around the last axis is kept.

        Returns:
            A composite SimpleITK transform, inverted if
            ``self.invert_transform`` is ``True``.
        """
        # Copies so the 2D adjustments below never alter the stored params
        scaling = np.array(self.scales).copy()
        rotation = np.array(self.degrees).copy()
        translation = np.array(self.translation).copy()

        if image.is_2d():
            scaling[2] = 1
            rotation[:-1] = 0

        if self.use_image_center:
            center_lps = image.get_center(lps=True)
        else:
            center_lps = None

        scaling_transform = self._get_scaling_transform(
            scaling,
            center_lps=center_lps,
        )
        rotation_transform = self._get_rotation_transform(
            rotation,
            translation,
            center_lps=center_lps,
        )

        sitk_major_version = get_major_sitk_version()
        if sitk_major_version < 2:
            # SimpleITK 1.x has no CompositeTransform class, so build a
            # generic composite transform and add the components one by one
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        else:
            # SimpleITK >= 2: previously any version other than 1 or 2 left
            # `transform` unbound and raised UnboundLocalError below
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.invert_transform:
            transform = transform.GetInverse()

        return transform

    def _get_default_value(self, tensor, sitk_image) -> float:
        """Return the fill value for one intensity channel.

        The value is chosen according to ``self.default_pad_value``.
        """
        if self.default_pad_value == 'minimum':
            return tensor.min().item()
        if self.default_pad_value == 'mean':
            return get_borders_mean(sitk_image, filter_otsu=False)
        if self.default_pad_value == 'otsu':
            return get_borders_mean(sitk_image, filter_otsu=True)
        return self.default_pad_value

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample every selected image of ``subject`` channel by channel.

        Non-intensity images use nearest-neighbor interpolation and a fill
        value of 0. Intensity images use ``self.image_interpolation`` and the
        fill value given by ``self.default_pad_value``.
        """
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            transform = self.get_affine_transform(image)
            is_intensity = image[TYPE] == INTENSITY
            transformed_tensors = []
            for tensor in image.data:
                sitk_image = nib_to_sitk(
                    tensor[np.newaxis],
                    image.affine,
                    force_3d=True,
                )
                if is_intensity:
                    interpolation = self.image_interpolation
                    default_value = self._get_default_value(tensor, sitk_image)
                else:
                    interpolation = 'nearest'
                    default_value = 0
                transformed_tensor = self.apply_affine_transform(
                    sitk_image,
                    transform,
                    interpolation,
                    default_value,
                )
                transformed_tensors.append(transformed_tensor)
            image.set_data(torch.stack(transformed_tensors))
        return subject

    def apply_affine_transform(
            self,
            sitk_image: sitk.Image,
            transform: sitk.Transform,
            interpolation: str,
            default_value: float,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample ``sitk_image`` on its own grid using ``transform``.

        Args:
            sitk_image: Image used as both the moving and the reference image.
            transform: SimpleITK transform passed to the resampler.
            interpolation: Interpolation name understood by
                ``self.get_sitk_interpolator``.
            default_value: Value assigned to voxels that map outside the image.
            center_lps: Unused; kept for backward compatibility.

        Returns:
            The resampled image as a float32 tensor, transposed from the
            SimpleITK array layout back to the original axis order.
        """
        floating = reference = sitk_image

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(self.get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.as_tensor(np_array)
        return tensor
352
353
354
# flake8: noqa: E201, E203, E243
355
def get_borders_mean(image, filter_otsu=True):
    """Return the mean value of the six faces of a 3D image.

    Edge and corner voxels belong to several faces and are therefore counted
    more than once.

    Args:
        image: SimpleITK image. The indexing below uses three axes, so a 3D
            image is expected.
        filter_otsu: If ``True``, only average the border values that lie
            strictly below the Otsu threshold of the border values. If no
            value is below the threshold, fall back to the mean of all border
            values.

    Returns:
        The mean as a NumPy scalar.
    """
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    # Otsu is run on a 1 x 1 x N image holding only the border values
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    # Check emptiness with .size: .any() is also False when every value
    # below the threshold is 0 (e.g. a zero background), which wrongly fell
    # back to the mean of all border values.
    if values.size > 0:
        default_value = values.mean()
    else:
        default_value = borders_flat.mean()
    return default_value
380
381
def _parse_scales_isotropic(scales, isotropic):
    """Check that ``scales`` can be used with isotropic scaling.

    ``scales`` is first normalized with ``to_tuple``.

    Raises:
        ValueError: If ``isotropic`` is truthy and the normalized ``scales``
            has 3 or 6 values (one value or range per axis).
    """
    scales_tuple = to_tuple(scales)
    if not isotropic:
        return
    if len(scales_tuple) not in (3, 6):
        return
    raise ValueError(
        'If "isotropic" is True, the value for "scales" must have'
        f' length 1 or 2, but "{scales_tuple}" was passed'
    )
389
390
def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
391
    if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
392
        return value
393
    message = (
394
        'Value for default_pad_value must be "minimum", "otsu", "mean"'
395
        ' or a number'
396
    )
397
    raise ValueError(message)
398