Passed
Pull Request — master (#353)
by Fernando
01:07
created

torchio.transforms.augmentation.spatial.random_affine   A

Complexity

Total Complexity 34

Size/Duplication

Total Lines 407
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 252
dl 0
loc 407
rs 9.68
c 0
b 0
f 0
wmc 34

9 Methods

Rating   Name   Duplication   Size   Complexity  
A Affine.get_rotation_transform() 0 13 2
A RandomAffine.apply_transform() 0 19 1
A Affine.get_arguments() 0 10 1
A Affine.__init__() 0 41 2
A RandomAffine.get_params() 0 13 2
A Affine.get_scaling_transform() 0 13 2
A RandomAffine.__init__() 0 27 2
B Affine.apply_affine_transform() 0 58 7
B Affine.apply_transform() 0 34 6

3 Functions

Rating   Name   Duplication   Size   Complexity  
A _parse_scales_isotropic() 0 8 3
A _parse_default_value() 0 8 3
A get_borders_mean() 0 25 3
1
import copy
2
from numbers import Number
3
from typing import Tuple, Optional, List, Union
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ....data.subject import Subject
10
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
11
from ....torchio import (
12
    INTENSITY,
13
    DATA,
14
    AFFINE,
15
    TYPE,
16
    TypeRangeFloat,
17
    TypeSextetFloat,
18
    TypeTripletFloat,
19
)
20
from ... import SpatialTransform
21
from .. import Interpolation, get_sitk_interpolator
22
from .. import RandomTransform
23
24
25
# Accepted forms for the ``scales``, ``degrees`` and ``translation``
# arguments of RandomAffine: 1, 2, 3 or 6 values, interpreted as described
# in the RandomAffine docstring
TypeOneToSixFloat = Union[
    TypeRangeFloat,
    TypeTripletFloat,
    TypeSextetFloat,
]
26
27
28
class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Scaling, rotation and translation parameters are sampled from the given
    ranges every time the transform is applied, and the resulting
    deterministic :py:class:`Affine` transform is applied to the subject.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`. In that case, ``scales`` must
            have one or two values.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If ``center`` is not ``'image'`` or ``'origin'``, or if
            ``isotropic`` is ``True`` and ``scales`` has 3 or 6 values.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=True,
        ...     image_interpolation='nearest',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "scales=(0.9, 1.2) degrees=10 isotropic=True image_interpolation=nearest" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.isotropic = isotropic
        # Reject per-axis scale ranges when a single factor will be sampled
        _parse_scales_isotropic(scales, isotropic)
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample scaling, rotation and translation values from their ranges.

        Returns:
            Tuple ``(scaling_params, rotation_params, translation_params)``.
            If ``isotropic`` is ``True``, all scaling values equal the first
            sampled one.
        """
        scaling_params = self.sample_uniform_sextet(scales)
        if isotropic:
            scaling_params.fill_(scaling_params[0])
        rotation_params = self.sample_uniform_sextet(degrees)
        translation_params = self.sample_uniform_sextet(translation)
        return scaling_params, rotation_params, translation_params

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample new parameters and apply them through :py:class:`Affine`."""
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        arguments = dict(
            scales=scaling_params.tolist(),
            degrees=rotation_params.tolist(),
            translation=translation_params.tolist(),
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.interpolation,
        )
        transform = Affine(**arguments)
        transformed = transform(subject)
        return transformed
161
162
163
class Affine(SpatialTransform):
    r"""Apply affine transformation.

    The transformation is the composition of a scaling transform and a
    rotation-plus-translation (Euler) transform, both optionally centered at
    the image center.

    Args:
        scales: Tuple :math:`(s_1, s_2, s_3)` defining the
            scaling values along each dimension. The values are inverted
            internally so that they are intuitive: e.g. :math:`1.5` makes the
            objects look 1.5 times larger and :math:`0.5` twice as small.
        degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
            rotation around each axis, in degrees.
        translation: Tuple :math:`(t_1, t_2, t_3)` defining the
            translation in mm along each axis.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If ``center`` is not ``'image'`` or ``'origin'``, or if
            ``default_pad_value`` is not a number, ``'minimum'``,
            ``'otsu'`` or ``'mean'``.
    """
    def __init__(
            self,
            scales: TypeTripletFloat,
            degrees: TypeTripletFloat,
            translation: TypeTripletFloat,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        self.scales = self.parse_params(
            scales,
            None,
            'scales',
            make_ranges=False,
            min_constraint=0,
        )
        self.degrees = self.parse_params(
            degrees,
            None,
            'degrees',
            make_ranges=False,
        )
        self.translation = self.parse_params(
            translation,
            None,
            'translation',
            make_ranges=False,
        )
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.use_image_center = center == 'image'
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)
        # When True, apply_affine_transform resamples with the inverse
        self.invert_transform = False

    def get_arguments(self):
        """Return the keyword arguments needed to rebuild this transform."""
        arguments = dict(
            scales=self.scales,
            degrees=self.degrees,
            translation=self.translation,
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.interpolation,
        )
        return arguments

    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scaling transform, optionally centered."""
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            translation: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build a SimpleITK Euler transform from angles (degrees) and a shift."""
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample every image of ``subject`` with this affine transform."""
        # Copies, because the 2D branch below modifies them in place
        scaling_params = np.array(self.scales).copy()
        rotation_params = np.array(self.degrees).copy()
        translation_params = np.array(self.translation).copy()
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            # Non-intensity images (presumably label maps) use nearest
            # neighbor so that their discrete values are preserved
            if image[TYPE] != INTENSITY:
                interpolation = 'nearest'
            else:
                interpolation = self.interpolation

            if image.is_2d():
                # Leave the third axis unscaled and rotate only around the
                # last axis (presumably so the image stays in its plane)
                scaling_params[2] = 1
                rotation_params[:-1] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            # Each channel is resampled independently, then restacked
            transformed_tensors = []
            for tensor in image[DATA]:
                transformed_tensor = self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                transformed_tensors.append(transformed_tensor)
            image[DATA] = torch.stack(transformed_tensors)
        return subject

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            translation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample a single 3D channel with the composed affine transform.

        Args:
            tensor: 3D tensor holding one channel of an image.
            affine: Affine matrix of the image, passed to ``nib_to_sitk``.
            scaling_params: Scaling factors (inverted internally).
            rotation_params: Rotation angles in degrees.
            translation_params: Translation in mm.
            interpolation: Interpolation used by the resampler.
            center_lps: Optional center of scaling and rotation.

        Returns:
            The resampled channel as a ``float32`` tensor with the same
            spatial shape as ``tensor``.
        """
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        # Resample onto the input grid, so the output keeps its geometry
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        transforms = [scaling_transform, rotation_transform]
        sitk_major_version = get_major_sitk_version()
        if sitk_major_version == 1:
            # SimpleITK 1.x composes transforms through the generic
            # Transform API
            transform = sitk.Transform(3, sitk.sitkComposite)
            for sub_transform in transforms:
                transform.AddTransform(sub_transform)
        else:
            # SimpleITK 2.0 and newer; the previous elif left ``transform``
            # undefined for any major version other than 1 or 2
            transform = sitk.CompositeTransform(transforms)

        if self.invert_transform:
            transform = transform.GetInverse()

        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor
361
362
363
# flake8: noqa: E201, E203, E243
364
def get_borders_mean(image, filter_otsu=True):
    """Return the mean intensity of the voxels on the six faces of ``image``.

    Args:
        image: SimpleITK image whose borders are sampled.
        filter_otsu: If ``True``, only border values strictly below the Otsu
            threshold of the border values are averaged. If no value lies
            below the threshold, all border values are averaged.

    Returns:
        The mean border value, used as the padding value when resampling.
    """
    array = sitk.GetArrayViewFromImage(image)
    borders_flat = _get_border_values(array)
    if not filter_otsu:
        return borders_flat.mean()
    # Otsu is computed by SimpleITK, so wrap the 1D values as a tiny image
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    return _mean_below_threshold(borders_flat, threshold)


def _get_border_values(array):
    """Concatenate the flattened six boundary slices of a 3D array."""
    # pylint: disable=bad-whitespace
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    return np.hstack([border.ravel() for border in borders_tuple])


def _mean_below_threshold(values, threshold):
    """Mean of ``values`` strictly below ``threshold``, else mean of all."""
    below = values[values < threshold]
    # Test emptiness, not truthiness: ``below.any()`` is False when every
    # below-threshold value is 0 (e.g. a zero background), which wrongly
    # fell back to the mean of all border values
    if below.size > 0:
        return below.mean()
    return values.mean()
389
390
def _parse_scales_isotropic(scales, isotropic):
    """Check that ``scales`` is compatible with isotropic scaling.

    With isotropic scaling, a single factor is sampled and reused for all
    axes, so per-axis specifications (3 or 6 values) are rejected.

    Args:
        scales: The ``scales`` argument given to RandomAffine; a number or a
            sequence of numbers.
        isotropic: Whether isotropic scaling was requested.

    Raises:
        ValueError: If ``isotropic`` is ``True`` and ``scales`` has 3 or 6
            values.
    """
    # ``scales`` may be a bare number (RandomAffine's default is 0.1), which
    # has no len(); measure the normalized tuple instead of the raw argument
    params = to_tuple(scales)
    if isotropic and len(params) in (3, 6):
        message = (
            'If "isotropic" is True, the value for "scales" must have'
            f' length 1 or 2, but "{scales}" was passed'
        )
        raise ValueError(message)
398
399
def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
    """Validate the ``default_pad_value`` argument of the affine transforms.

    Args:
        value: A number, or one of ``'minimum'``, ``'otsu'`` or ``'mean'``.

    Returns:
        ``value`` itself, unchanged, when it is valid.

    Raises:
        ValueError: If ``value`` is neither a number nor a supported string.
    """
    supported_modes = ('minimum', 'otsu', 'mean')
    is_valid = isinstance(value, Number) or value in supported_modes
    if not is_valid:
        message = (
            'Value for default_pad_value must be "minimum", "otsu", "mean"'
            ' or a number'
        )
        raise ValueError(message)
    return value
407