Passed
Push — master ( e28dfa...5bce14 )
by Fernando
04:12
created

Affine._get_rotation_transform()   A

Complexity

Conditions 2

Size

Total Lines 13
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 12
dl 0
loc 13
rs 9.8
c 0
b 0
f 0
cc 2
nop 3
1
from numbers import Number
2
from typing import Tuple, Optional, Sequence, Union
3
4
import torch
5
import numpy as np
6
import SimpleITK as sitk
7
8
from ....data.io import nib_to_sitk
9
from ....data.subject import Subject
10
from ....constants import INTENSITY, TYPE
11
from ....utils import get_major_sitk_version, to_tuple
12
from ....typing import TypeRangeFloat, TypeSextetFloat, TypeTripletFloat
13
from ... import SpatialTransform
14
from .. import RandomTransform
15
16
17
# Union of the range / triplet / sextet float aliases. Used below to annotate
# the ``scales``, ``degrees`` and ``translation`` arguments of RandomAffine,
# which accept any of these forms (see the RandomAffine docstring).
TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]
18
19
20
class RandomAffine(RandomTransform, SpatialTransform):
    r"""Apply a random affine transformation and resample the image.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`. In that case, ``scales`` must
            contain one or two values; three or six values raise a
            ``ValueError``.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=True,
        ...     image_interpolation='nearest',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "scales=(0.9, 1.2) degrees=10 isotropic=True image_interpolation=nearest" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            **kwargs
            ):
        super().__init__(**kwargs)
        self.isotropic = isotropic
        # Reject per-axis scales combined with isotropic=True before parsing
        _parse_scales_isotropic(scales, isotropic)
        # Scales are parsed around 1 (identity) and must not go below 0;
        # rotations and translations are parsed around 0
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        valid_centers = 'image', 'origin'
        if center not in valid_centers:
            raise ValueError(
                f'Center argument must be "image" or "origin", not "{center}"'
            )
        self.center = center
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample the scaling, rotation and translation parameters.

        The three groups are sampled in this order: scaling, rotation,
        translation. When ``isotropic`` is truthy, every scaling factor is
        overwritten in place with the first sampled one.

        Returns:
            Tuple ``(scaling, rotation, translation)`` of the sampled tensors.
        """
        sampled_scaling = self.sample_uniform_sextet(scales)
        if isotropic:
            sampled_scaling.fill_(sampled_scaling[0])
        sampled_rotation = self.sample_uniform_sextet(degrees)
        sampled_translation = self.sample_uniform_sextet(translation)
        return sampled_scaling, sampled_rotation, sampled_translation

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample parameters and delegate the resampling to :class:`Affine`."""
        subject.check_consistent_spatial_shape()
        sampled = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        scaling_params, rotation_params, translation_params = sampled
        affine_kwargs = {
            'scales': scaling_params.tolist(),
            'degrees': rotation_params.tolist(),
            'translation': translation_params.tolist(),
            'center': self.center,
            'default_pad_value': self.default_pad_value,
            'image_interpolation': self.image_interpolation,
        }
        affine = Affine(**self.add_include_exclude(affine_kwargs))
        return affine(subject)
151
152
153
class Affine(SpatialTransform):
    r"""Apply affine transformation.

    Args:
        scales: Tuple :math:`(s_1, s_2, s_3)` defining the
            scaling values along each dimension.
            Values are inverted before being handed to SimpleITK so that they
            are intuitive: e.g. ``1.5`` makes the objects look 1.5 times
            larger.
        degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
            rotation around each axis, in degrees.
        translation: Tuple :math:`(t_1, t_2, t_3)` defining the
            translation in mm along each axis.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
            Non-intensity images are always resampled with nearest-neighbor
            interpolation and padded with 0.
        image_interpolation: See :ref:`Interpolation`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. note:: For 2D images (``image.is_2d()``), only the in-plane part of
        the transform is applied: the third scaling factor is forced to 1,
        and the rotations around the first two axes and the translation
        along the third axis are forced to 0.
    """
    def __init__(
            self,
            scales: TypeTripletFloat,
            degrees: TypeTripletFloat,
            translation: TypeTripletFloat,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            **kwargs
            ):
        super().__init__(**kwargs)
        self.scales = self.parse_params(
            scales,
            None,
            'scales',
            make_ranges=False,
            min_constraint=0,
        )
        self.degrees = self.parse_params(
            degrees,
            None,
            'degrees',
            make_ranges=False,
        )
        self.translation = self.parse_params(
            translation,
            None,
            'translation',
            make_ranges=False,
        )
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.use_image_center = center == 'image'
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)
        self.invert_transform = False
        self.args_names = (
            'scales',
            'degrees',
            'translation',
            'center',
            'default_pad_value',
            'image_interpolation',
        )

    @staticmethod
    def _get_scaling_transform(
            scaling_params: Sequence[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scaling transform.

        Args:
            scaling_params: Three scaling factors, as documented on the class.
            center_lps: Optional center of scaling, forwarded to
                ``SetCenter`` unchanged.
        """
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        inverse_scales = 1 / np.asarray(scaling_params, dtype=float)
        # Hand SimpleITK plain Python floats rather than a NumPy array
        transform.SetScale(inverse_scales.tolist())
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def _get_rotation_transform(
            degrees: Sequence[float],
            translation: Sequence[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build a SimpleITK Euler transform (rotation plus translation).

        Args:
            degrees: Three rotation angles in degrees, converted to radians.
            translation: Three translation values, forwarded to
                ``SetTranslation``.
            center_lps: Optional center of rotation, forwarded to
                ``SetCenter`` unchanged.
        """
        transform = sitk.Euler3DTransform()
        radians = np.radians(np.asarray(degrees, dtype=float))
        transform.SetRotation(*radians.tolist())
        transform.SetTranslation(np.asarray(translation, dtype=float).tolist())
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def get_affine_transform(self, image):
        """Return the SimpleITK transform to resample ``image`` with.

        The transform composes the scaling transform and the Euler
        (rotation + translation) transform, both centered on the image center
        when ``center='image'``. It is inverted when ``self.invert_transform``
        is set.
        """
        # np.array(..., dtype=float) returns fresh float copies, so the 2D
        # adjustments below never mutate the stored parameters
        scaling = np.array(self.scales, dtype=float)
        rotation = np.array(self.degrees, dtype=float)
        translation = np.array(self.translation, dtype=float)

        if image.is_2d():
            # Keep only in-plane transformations: a single-slice image has
            # nothing to sample once it is scaled, tilted or shifted along
            # the third axis
            scaling[2] = 1
            rotation[:-1] = 0
            translation[2] = 0

        if self.use_image_center:
            center_lps = image.get_center(lps=True)
        else:
            center_lps = None

        scaling_transform = self._get_scaling_transform(
            scaling,
            center_lps=center_lps,
        )
        rotation_transform = self._get_rotation_transform(
            rotation,
            translation,
            center_lps=center_lps,
        )

        if get_major_sitk_version() == 1:
            # sitk.CompositeTransform is only available from SimpleITK 2.0
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        else:
            # Any newer major version: previously `transform` was left
            # undefined for versions other than 1 and 2
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.invert_transform:
            transform = transform.GetInverse()

        return transform

    def _get_interpolation_and_pad_value(self, image, tensor, sitk_image):
        """Return ``(interpolation, default_value)`` for one image channel.

        Non-intensity images always use ``'nearest'`` and a fill value of 0.
        Otherwise the configured image interpolation is used and the fill
        value is derived from ``self.default_pad_value``.
        """
        if image[TYPE] != INTENSITY:
            return 'nearest', 0
        pad_value = self.default_pad_value
        if pad_value == 'minimum':
            default_value = tensor.min().item()
        elif pad_value == 'mean':
            default_value = get_borders_mean(sitk_image, filter_otsu=False)
        elif pad_value == 'otsu':
            default_value = get_borders_mean(sitk_image, filter_otsu=True)
        else:
            # A numeric value validated by _parse_default_value
            default_value = pad_value
        return self.image_interpolation, default_value

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample every channel of every selected image in ``subject``."""
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            transform = self.get_affine_transform(image)
            transformed_tensors = []
            for tensor in image.data:  # one channel at a time
                sitk_image = nib_to_sitk(
                    tensor[np.newaxis],
                    image.affine,
                    force_3d=True,
                )
                interpolation, default_value = (
                    self._get_interpolation_and_pad_value(
                        image, tensor, sitk_image,
                    )
                )
                transformed_tensor = self.apply_affine_transform(
                    sitk_image,
                    transform,
                    interpolation,
                    default_value,
                )
                transformed_tensors.append(transformed_tensor)
            image.set_data(torch.stack(transformed_tensors))
        return subject

    def apply_affine_transform(
            self,
            sitk_image: sitk.Image,
            transform: sitk.Transform,
            interpolation: str,
            default_value: float,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample ``sitk_image`` with ``transform`` on its own grid.

        Args:
            sitk_image: Image used both as the input and as the reference
                grid of the output.
            transform: Transform set on the resampler.
            interpolation: Interpolation name, mapped through
                ``get_sitk_interpolator``.
            default_value: Value for output voxels mapped outside the input.
            center_lps: Unused; kept for backward compatibility.

        Returns:
            The resampled image as a ``float32`` tensor, transposed from
            SimpleITK's array axis order.
        """
        floating = reference = sitk_image

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(self.get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        return torch.as_tensor(np_array)
352
353
354
# flake8: noqa: E201, E203, E243
355
def get_borders_mean(image, filter_otsu=True):
    """Return the mean of the values on the six faces of ``image``.

    Voxels on edges and corners belong to several faces and are therefore
    counted more than once.

    Args:
        image: SimpleITK image whose array view is indexed with three indices
            below (i.e. a 3D image).
        filter_otsu: If ``True``, average only the border values strictly
            below an Otsu threshold computed on the border values; if no value
            lies below it, fall back to the mean of all border values.
    """
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    # Otsu filter needs an image; wrap the flat border values in a 1x1xN one
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    # Test for emptiness with .size: .any() is False for an all-zero
    # selection (e.g. a zero background), which wrongly triggered the fallback
    if values.size > 0:
        return values.mean()
    return borders_flat.mean()
380
381
def _parse_scales_isotropic(scales, isotropic):
    """Reject per-axis ``scales`` when isotropic scaling is requested.

    Raises:
        ValueError: If ``isotropic`` is truthy and ``scales`` (after
            ``to_tuple``) has three or six elements.
    """
    as_tuple = to_tuple(scales)
    if not isotropic:
        return
    if len(as_tuple) not in (3, 6):
        return
    raise ValueError(
        'If "isotropic" is True, the value for "scales" must have'
        f' length 1 or 2, but "{as_tuple}" was passed'
    )
389
390
def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
391
    if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
392
        return value
393
    message = (
394
        'Value for default_pad_value must be "minimum", "otsu", "mean"'
395
        ' or a number'
396
    )
397
    raise ValueError(message)
398