Passed
Pull Request — master (#353)
by Fernando
01:18
created

RandomAffine.get_params()   A

Complexity

Conditions 2

Size

Total Lines 13
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 12
nop 5
dl 0
loc 13
rs 9.8
c 0
b 0
f 0
1
import copy
2
from numbers import Number
3
from typing import Tuple, Optional, Sequence, Union
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ....data.subject import Subject
10
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
11
from ....torchio import (
12
    INTENSITY,
13
    DATA,
14
    AFFINE,
15
    TYPE,
16
    TypeRangeFloat,
17
    TypeSextetFloat,
18
    TypeTripletFloat,
19
)
20
from ... import SpatialTransform
21
from .. import RandomTransform
22
23
24
# Accepted type for the ``scales``, ``degrees`` and ``translation`` arguments
# of RandomAffine: a range, a triplet or a sextet of floats (the docstring of
# RandomAffine explains how each form is interpreted).
TypeOneToSixFloat = Union[
    TypeRangeFloat,
    TypeTripletFloat,
    TypeSextetFloat,
]
25
26
27
class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Scaling, rotation and translation parameters are sampled each time the
    transform is called, and the result is applied to the subject with
    :py:class:`Affine`.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, ``scales=(0.5, 0.5)`` zooms out the image: the
            objects inside look half as large, while the physical size and
            position of the image bounds are preserved.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, one scaling factor is used for all the
            dimensions, i.e. :math:`s_1 = s_2 = s_3`. In that case,
            ``scales`` must have one or two values.
        center: If ``'image'``, rotations and scaling are performed around
            the image center. If ``'origin'``, they are performed around the
            origin in world coordinates.
        default_pad_value: Value used to fill the voxels near the borders
            that become undefined after the transformation.
            If ``'minimum'``, the image minimum is used.
            If ``'mean'``, the mean of the border values is used.
            If ``'otsu'``, the mean of the border values that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_
            is used.
            If it is a number, that value is used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=True,
        ...     image_interpolation='nearest',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "scales=(0.9, 1.2) degrees=10 isotropic=True image_interpolation=nearest" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.isotropic = isotropic
        _parse_scales_isotropic(scales, isotropic)
        # Scaling ranges are built around 1 (identity scaling), while the
        # rotation and translation ranges are built around 0
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        valid_centers = 'image', 'origin'
        if center not in valid_centers:
            raise ValueError(
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
        self.center = center
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(
            image_interpolation)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample the scaling, rotation and translation parameters.

        Each argument is a sextet of range bounds from which values are drawn
        with ``sample_uniform_sextet``. The samples are drawn in this order:
        scaling, rotation, translation.

        Args:
            scales: Scaling ranges.
            degrees: Rotation ranges, in degrees.
            translation: Translation ranges, in mm.
            isotropic: If ``True``, every scaling value is overwritten with
                the first sampled one.

        Returns:
            Tuple ``(scaling, rotation, translation)`` of sampled tensors.
        """
        sampled_scales = self.sample_uniform_sextet(scales)
        if isotropic:
            sampled_scales.fill_(sampled_scales[0])
        sampled_degrees = self.sample_uniform_sextet(degrees)
        sampled_translation = self.sample_uniform_sextet(translation)
        return sampled_scales, sampled_degrees, sampled_translation

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample new parameters and apply them to ``subject`` via Affine."""
        subject.check_consistent_spatial_shape()
        scales, degrees, translation = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        affine_transform = Affine(
            scales=scales.tolist(),
            degrees=degrees.tolist(),
            translation=translation.tolist(),
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.image_interpolation,
        )
        return affine_transform(subject)
159
160
161
class Affine(SpatialTransform):
    r"""Apply affine transformation.

    Args:
        scales: Tuple :math:`(s_1, s_2, s_3)` defining the
            scaling values along each dimension. Values greater than 1 make
            the objects in the image look larger.
        degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
            rotation around each axis, in degrees.
        translation: Tuple :math:`(t_1, t_2, t_3)` defining the
            translation in mm along each axis.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. note:: For 2D images (a single slice along the last spatial axis),
        only the rotation around the last axis is applied. The scaling and
        translation along that axis are ignored, because they would
        otherwise move the slice away from the sampling grid.
    """
    def __init__(
            self,
            scales: TypeTripletFloat,
            degrees: TypeTripletFloat,
            translation: TypeTripletFloat,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(keys=keys)
        self.scales = self.parse_params(
            scales,
            None,
            'scales',
            make_ranges=False,
            min_constraint=0,
        )
        self.degrees = self.parse_params(
            degrees,
            None,
            'degrees',
            make_ranges=False,
        )
        self.translation = self.parse_params(
            translation,
            None,
            'translation',
            make_ranges=False,
        )
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.use_image_center = center == 'image'
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)
        # When True, the inverse of the composed transform is used instead
        self.invert_transform = False
        self.args_names = (
            'scales',
            'degrees',
            'translation',
            'center',
            'default_pad_value',
            'image_interpolation',
        )

    @staticmethod
    def get_scaling_transform(
            scaling_params: Sequence[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scaling transform.

        Args:
            scaling_params: Scaling factor along each axis.
            center_lps: Optional center of the scaling, in LPS coordinates.
        """
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: Sequence[float],
            translation: Sequence[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build a SimpleITK Euler transform (rotation plus translation).

        Args:
            degrees: Rotation angle around each axis, in degrees.
            translation: Translation along each axis, in mm.
            center_lps: Optional center of the rotation, in LPS coordinates.
        """
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample every selected image of ``subject`` in place."""
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            # Non-intensity images (e.g. label maps) must not be blended
            if image[TYPE] != INTENSITY:
                interpolation = 'nearest'
            else:
                interpolation = self.image_interpolation

            # Fresh copies for each image, so that the 2D adjustments below
            # never leak into the parameters used for another image
            scaling_params = np.array(self.scales, dtype=float)
            rotation_params = np.array(self.degrees, dtype=float)
            translation_params = np.array(self.translation, dtype=float)

            if image.is_2d():
                # The image is a single slice along the last axis: keep only
                # the in-plane rotation, and do not scale or translate along
                # that axis, which would move the slice away from the
                # sampling grid and produce an image full of padding
                scaling_params[2] = 1
                rotation_params[:-1] = 0
                translation_params[2] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            transformed_tensors = [
                self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                for tensor in image[DATA]  # one channel at a time
            ]
            image[DATA] = torch.stack(transformed_tensors)
        return subject

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: Sequence[float],
            rotation_params: Sequence[float],
            translation_params: Sequence[float],
            interpolation: str,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample a single 3D channel with the composed affine transform.

        Args:
            tensor: One channel of the image, with three spatial dimensions.
            affine: Affine matrix of the image, passed to ``nib_to_sitk``.
            scaling_params: Scaling factor along each axis.
            rotation_params: Rotation angle around each axis, in degrees.
            translation_params: Translation along each axis, in mm.
            interpolation: Name of the interpolation to use.
            center_lps: Optional center of rotation and scaling, in LPS
                coordinates.

        Returns:
            The resampled channel as a ``float32`` tensor, sampled on the
            same grid as the input.
        """
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        # Resample onto the grid of the input image itself
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        if get_major_sitk_version() == 1:
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        else:
            # SimpleITK >= 2 provides a dedicated composite transform class.
            # Using `else` guarantees `transform` is always bound.
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.invert_transform:
            transform = transform.GetInverse()

        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(self.get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor
356
357
358
# flake8: noqa: E201, E203, E243
359
def get_borders_mean(image, filter_otsu=True):
    """Return the mean of the values on the six faces of a 3D image.

    Args:
        image: SimpleITK image. Its array view is indexed with three axes,
            so a 3D image is expected.
        filter_otsu: If ``True``, only the border values below an Otsu
            threshold (computed on the border values themselves) are
            averaged. If no border value lies below the threshold, the mean
            of all the border values is returned instead.

    Returns:
        The mean, as a NumPy scalar.
    """
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[0, :, :],
        array[-1, :, :],
        array[:, 0, :],
        array[:, -1, :],
        array[:, :, 0],
        array[:, :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    # Test for emptiness with `.size`. `values.any()` is False whenever all
    # the selected values are zero (e.g. a zero background), which would
    # silently discard the Otsu filtering.
    if values.size:
        return values.mean()
    return borders_flat.mean()
384
385
def _parse_scales_isotropic(scales, isotropic):
    """Check that ``scales`` can be used for isotropic scaling.

    Args:
        scales: Value passed as ``scales`` to RandomAffine. It may be a
            single number or a sequence.
        isotropic: Whether a single scaling factor is used for all axes.

    Raises:
        ValueError: If ``isotropic`` is ``True`` and ``scales`` has three or
            six values.
    """
    params = to_tuple(scales)
    # Measure the normalized tuple: `scales` may be a bare number (such as
    # the default 0.1), which has no length
    if isotropic and len(params) in (3, 6):
        message = (
            'If "isotropic" is True, the value for "scales" must have'
            f' length 1 or 2, but "{scales}" was passed'
        )
        raise ValueError(message)
393
394
def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
395
    if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
396
        return value
397
    message = (
398
        'Value for default_pad_value must be "minimum", "otsu", "mean"'
399
        ' or a number'
400
    )
401
    raise ValueError(message)
402