Passed
Pull Request — master (#353)
by Fernando
01:16
created

RandomAffine.apply_transform() — rated A

Complexity
    Conditions: 1

Size
    Total Lines: 19
    Code Lines: 17

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0
Metric    Value
cc        1       (cyclomatic complexity)
eloc      17      (code lines)
nop       2       (number of parameters)
dl        0       (duplicated lines)
loc       19      (total lines)
rs        9.55
c         0
b         0
f         0

These figures describe RandomAffine.apply_transform() as shown below: the method body is a single straight-line path with no branches (cc 1), it takes two parameters (self and subject), and it spans 19 lines, 17 of which are code.
import copy
from numbers import Number
from typing import Tuple, Optional, List, Union

import torch
import numpy as np
import SimpleITK as sitk

from ....data.subject import Subject
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
from ....torchio import (
    INTENSITY,
    DATA,
    AFFINE,
    TYPE,
    TypeRangeFloat,
    TypeSextetFloat,
    TypeTripletFloat,
)
from ... import SpatialTransform
from .. import Interpolation, get_sitk_interpolator
from .. import RandomTransform


TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]


class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=True,
        ...     image_interpolation='nearest',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "scales=(0.9, 1.2) degrees=10 isotropic=True image_interpolation=nearest" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.isotropic = isotropic
        _parse_scales_isotropic(scales, isotropic)
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        scaling_params = self.sample_uniform_sextet(scales)
        if isotropic:
            scaling_params.fill_(scaling_params[0])
        rotation_params = self.sample_uniform_sextet(degrees)
        translation_params = self.sample_uniform_sextet(translation)
        return scaling_params, rotation_params, translation_params

    def apply_transform(self, subject: Subject) -> Subject:
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        arguments = dict(
            scales=scaling_params.tolist(),
            degrees=rotation_params.tolist(),
            translation=translation_params.tolist(),
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.image_interpolation,
        )
        transform = Affine(**arguments)
        transformed = transform(subject)
        return transformed
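# How RandomAffine (above) and Affine (below) relate: RandomAffine only samples
# concrete parameters and then delegates to the deterministic Affine transform
# defined next. Per the docstring above, a scalar argument is expanded into
# symmetric per-axis ranges, e.g. scales=0.1 becomes (0.9, 1.1) for every axis
# and degrees=10 becomes (-10, 10), from which get_params() draws one value per
# axis.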
class Affine(SpatialTransform):
    r"""Apply affine transformation.

    Args:
        scales: Tuple :math:`(s_1, s_2, s_3)` defining the
            scaling values along each dimension.
        degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
            rotation around each axis.
        translation: Tuple :math:`(t_1, t_2, t_3)` defining the
            translation in mm along each axis.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        keys: See :py:class:`~torchio.transforms.Transform`.
    """
    def __init__(
            self,
            scales: TypeTripletFloat,
            degrees: TypeTripletFloat,
            translation: TypeTripletFloat,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        self.scales = self.parse_params(
            scales,
            None,
            'scales',
            make_ranges=False,
            min_constraint=0,
        )
        self.degrees = self.parse_params(
            degrees,
            None,
            'degrees',
            make_ranges=False,
        )
        self.translation = self.parse_params(
            translation,
            None,
            'translation',
            make_ranges=False,
        )
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.use_image_center = center == 'image'
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)
        self.invert_transform = False
        self.args_names = (
            'scales',
            'degrees',
            'translation',
            'center',
            'default_pad_value',
            'image_interpolation',
        )

    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            translation: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        scaling_params = np.array(self.scales).copy()
        rotation_params = np.array(self.degrees).copy()
        translation_params = np.array(self.translation).copy()
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            if image[TYPE] != INTENSITY:
                interpolation = 'nearest'
            else:
                interpolation = self.image_interpolation

            if image.is_2d():
                scaling_params[2] = 1
                rotation_params[:-1] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            transformed_tensors = []
            for tensor in image[DATA]:
                # Review note (Comprehensibility / Best Practice): "The variable
                # DATA does not seem to be defined." DATA is imported from
                # ....torchio at the top of the file, so this appears to be a
                # false positive.
                transformed_tensor = self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                transformed_tensors.append(transformed_tensor)
            image[DATA] = torch.stack(transformed_tensors)
        return subject

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            translation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        sitk_major_version = get_major_sitk_version()
        if sitk_major_version == 1:
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        elif sitk_major_version == 2:
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.invert_transform:
            transform = transform.GetInverse()
            # Review note: "The variable transform does not seem to be defined
            # for all execution paths."

        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor


# flake8: noqa: E201, E203, E243
def get_borders_mean(image, filter_otsu=True):
    """Return the mean intensity of the border voxels of ``image``.

    If ``filter_otsu`` is ``True``, only border values that lie below an Otsu
    threshold (i.e., presumably background) are averaged; if that selection
    contains no nonzero values, the mean of all border values is returned
    instead.
    """
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    if values.any():
        default_value = values.mean()
    else:
        default_value = borders_flat.mean()
    return default_value

def _parse_scales_isotropic(scales, isotropic):
    params = to_tuple(scales)
    if isotropic and len(params) in (3, 6):
        message = (
            'If "isotropic" is True, the value for "scales" must have'
            f' length 1 or 2, but "{scales}" was passed'
        )
        raise ValueError(message)

def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
    if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
        return value
    message = (
        'Value for default_pad_value must be "minimum", "otsu", "mean"'
        ' or a number'
    )
    raise ValueError(message)
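
For reference, here is a minimal usage sketch of the deterministic Affine class defined above, mirroring what RandomAffine.apply_transform() does internally (sample concrete parameters, then delegate). The import path and the direct call to get_params() are illustrative assumptions rather than documented API; the Colin27 subject comes from the docstring example.

    import torchio as tio
    # Assumed module path for this file; Affine may not be exported at package level.
    from torchio.transforms.augmentation.spatial.random_affine import Affine

    subject = tio.datasets.Colin27()

    # Deterministic transform: one concrete value per axis, no sampling.
    transform = Affine(
        scales=(1.1, 1.1, 1.1),      # objects appear 1.1 times larger
        degrees=(0, 0, 10),          # 10-degree rotation around the third axis
        translation=(5, 0, 0),       # 5 mm translation along the first axis
        default_pad_value='otsu',
        image_interpolation='linear',
    )
    transformed = transform(subject)

    # RandomAffine.apply_transform() is equivalent to sampling parameters
    # and applying an Affine built from them:
    random_affine = tio.RandomAffine(scales=0.1, degrees=10)
    scales, degrees, translation = random_affine.get_params(
        random_affine.scales,
        random_affine.degrees,
        random_affine.translation,
        random_affine.isotropic,
    )
    randomly_transformed = Affine(
        scales=scales.tolist(),
        degrees=degrees.tolist(),
        translation=translation.tolist(),
    )(subject)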
404