Passed · Pull Request — master (#353) · created by Fernando · 01:11

Affine.inverse()   A

Complexity:  Conditions 1
Size:        Total Lines 4 · Code Lines 4
Duplication: Lines 0 · Ratio 0 %
Importance:  Changes 0

Metric  Value
cc      1
eloc    4
nop     1
dl      0
loc     4
rs      10
c       0
b       0
f       0
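The four measured lines above correspond to the new Affine.inverse() method in the listing below, which returns a deep copy of the transform with invert_transform toggled, so the copy resamples with the inverse of the composed scaling and rotation. A minimal usage sketch, assuming Affine and the Colin27 dataset are exposed at the package level as elsewhere in torchio; the round trip is only approximate because of interpolation and border padding:

    >>> import torchio as tio
    >>> subject = tio.datasets.Colin27()
    >>> forward = tio.Affine(scales=(1, 1, 1), degrees=(0, 0, 10), translation=(5, 0, 0))
    >>> rotated = forward(subject)
    >>> backward = forward.inverse()  # deep copy with invert_transform=True
    >>> restored = backward(rotated)  # approximately undoes the rotation and translation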
import copy
from numbers import Number
from typing import Tuple, Optional, List, Union

import torch
import numpy as np
import SimpleITK as sitk

from ....data.subject import Subject
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
from ....torchio import (
    INTENSITY,
    DATA,
    AFFINE,
    TYPE,
    TypeRangeFloat,
    TypeSextetFloat,
    TypeTripletFloat,
)
from ... import SpatialTransform
from .. import Interpolation, get_sitk_interpolator
from .. import RandomTransform


TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]


class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=(10),
        ...     isotropic=False,
        ...     default_pad_value='otsu',
        ...     image_interpolation='bspline',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "degrees=30 default_pad_value=minimum" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.isotropic = isotropic
        _parse_scales_isotropic(scales, isotropic)
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        scaling_params = self.sample_uniform_sextet(scales)
        if isotropic:
            scaling_params.fill_(scaling_params[0])
        rotation_params = self.sample_uniform_sextet(degrees)
        translation_params = self.sample_uniform_sextet(translation)
        return scaling_params, rotation_params, translation_params

    def apply_transform(self, subject: Subject) -> Subject:
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        arguments = dict(
            scales=scaling_params.tolist(),
            degrees=rotation_params.tolist(),
            translation=translation_params.tolist(),
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.interpolation,
        )
        transform = Affine(**arguments)
        transformed = transform(subject)
        transformed.add_transform(transform, arguments)
        return transformed


class Affine(SpatialTransform):
    r"""Apply affine transformation.

    Args:
        scales: Tuple :math:`(s_1, s_2, s_3)` defining the
            scaling values along each dimension.
        degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
            rotation around each axis.
        translation: Tuple :math:`(t_1, t_2, t_3)` defining the
            translation in mm along each axis.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        keys: See :py:class:`~torchio.transforms.Transform`.
    """
    def __init__(
            self,
            scales: TypeTripletFloat,
            degrees: TypeTripletFloat,
            translation: TypeTripletFloat,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            keys: Optional[List[str]] = None,
            ):
        super().__init__(keys=keys)
        self.scales = self.parse_params(
            scales,
            None,
            'scales',
            make_ranges=False,
            min_constraint=0,
        )
        self.degrees = self.parse_params(
            degrees,
            None,
            'degrees',
            make_ranges=False,
        )
        self.translation = self.parse_params(
            translation,
            0,
            'translation',
            make_ranges=False,
        )
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.use_image_center = center == 'image'
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)
        self.invert_transform = False

    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            translation: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        scaling_params = np.array(self.scales).copy()
        rotation_params = np.array(self.degrees).copy()
        translation_params = np.array(self.translation).copy()
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            if image[TYPE] != INTENSITY:
                interpolation = Interpolation.NEAREST
            else:
                interpolation = self.interpolation

            if image.is_2d():
                scaling_params[2] = 1
                rotation_params[:-1] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            transformed_tensors = []
            for tensor in image[DATA]:
                # Analysis issue (Comprehensibility / Best Practice):
                # the variable DATA does not seem to be defined.
                transformed_tensor = self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                transformed_tensors.append(transformed_tensor)
            image[DATA] = torch.stack(transformed_tensors)
        return subject

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            translation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        sitk_major_version = get_major_sitk_version()
        if sitk_major_version == 1:
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        elif sitk_major_version == 2:
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.invert_transform:
            transform = transform.GetInverse()
            # Analysis issue: the variable transform does not seem to be
            # defined for all execution paths.

        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor

    def inverse(self):
        new = copy.deepcopy(self)
        new.invert_transform = not self.invert_transform
        return new

# flake8: noqa: E201, E203, E243
def get_borders_mean(image, filter_otsu=True):
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    if values.any():
        default_value = values.mean()
    else:
        default_value = borders_flat.mean()
    return default_value

def _parse_scales_isotropic(scales, isotropic):
    params = to_tuple(scales)  # scales may be a single number
    if isotropic and len(params) in (3, 6):
        message = (
            'If "isotropic" is True, the value for "scales" must have'
            f' length 1 or 2, but "{scales}" was passed'
        )
        raise ValueError(message)

def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
    if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
        return value
    message = (
        'Value for default_pad_value must be "minimum", "otsu", "mean"'
        ' or a number'
    )
    raise ValueError(message)
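For reference, a short sketch of how the border-padding helper above behaves; the toy volume and its values are invented for illustration, and get_borders_mean is the module-level helper from the listing:

    import numpy as np
    import SimpleITK as sitk

    # Toy volume: dark background (value 5) with a bright slab (value 100)
    # touching one face, so some border voxels are bright.
    array = np.full((10, 10, 10), 5, dtype=np.float32)
    array[:, :, -3:] = 100
    image = sitk.GetImageFromArray(array)

    mean_fill = get_borders_mean(image, filter_otsu=False)  # pulled up by the bright border voxels
    otsu_fill = get_borders_mean(image, filter_otsu=True)   # stays close to the dark background value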