Passed
Pull Request — master (#353)
by Fernando
01:17
created

Affine.apply_transform()   B

Complexity

Conditions 6

Size

Total Lines 34
Code Lines 28

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 28
nop 2
dl 0
loc 34
rs 8.2746
c 0
b 0
f 0
1
import copy
2
from numbers import Number
3
from typing import Tuple, Optional, Sequence, Union
4
5
import torch
6
import numpy as np
7
import SimpleITK as sitk
8
9
from ....data.subject import Subject
10
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
11
from ....torchio import (
12
    INTENSITY,
13
    DATA,
14
    AFFINE,
15
    TYPE,
16
    TypeRangeFloat,
17
    TypeSextetFloat,
18
    TypeTripletFloat,
19
)
20
from ... import SpatialTransform
21
from .. import Interpolation, get_sitk_interpolator
22
from .. import RandomTransform
23
24
25
# Type of the ``scales``, ``degrees`` and ``translation`` arguments of
# RandomAffine: a range, a triplet or a sextet of floats. See the RandomAffine
# docstring for how each accepted form is interpreted.
TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]
26
27
28
class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=True,
        ...     image_interpolation='nearest',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "scales=(0.9, 1.2) degrees=10 isotropic=True image_interpolation=nearest" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.isotropic = isotropic
        # Reject per-axis scaling ranges when isotropic scaling is requested
        _parse_scales_isotropic(scales, isotropic)
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample scaling, rotation and translation parameters.

        Args:
            scales: Sextet of scaling ranges.
            degrees: Sextet of rotation ranges, in degrees.
            translation: Sextet of translation ranges, in mm.
            isotropic: If ``True``, every scaling value is replaced by the
                first sampled one.

        Returns:
            Tensors of sampled scaling, rotation and translation parameters.
            They are forwarded to :class:`Affine` as per-axis triplets.
        """
        scaling_params = self.sample_uniform_sextet(scales)
        if isotropic:
            # Same scaling along every axis: reuse the first sampled value
            scaling_params.fill_(scaling_params[0])
        rotation_params = self.sample_uniform_sextet(degrees)
        translation_params = self.sample_uniform_sextet(translation)
        return scaling_params, rotation_params, translation_params

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample random parameters and apply them with :class:`Affine`.

        Args:
            subject: Subject whose images are transformed.

        Returns:
            The subject returned by the deterministic :class:`Affine`
            transform.
        """
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        arguments = dict(
            scales=scaling_params.tolist(),
            degrees=rotation_params.tolist(),
            translation=translation_params.tolist(),
            center=self.center,
            default_pad_value=self.default_pad_value,
            image_interpolation=self.image_interpolation,
            # Forward the image selection. Without it the inner transform
            # defaults to keys=None and would modify every image in the
            # subject, ignoring the keys given to this random transform.
            keys=self.keys,
        )
        transform = Affine(**arguments)
        transformed = transform(subject)
        return transformed
160
161
162
class Affine(SpatialTransform):
    r"""Apply affine transformation.

    Args:
        scales: Tuple :math:`(s_1, s_2, s_3)` defining the
            scaling values along each dimension.
        degrees: Tuple :math:`(\theta_1, \theta_2, \theta_3)` defining the
            rotation around each axis.
        translation: Tuple :math:`(t_1, t_2, t_3)` defining the
            translation in mm along each axis.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        keys: See :py:class:`~torchio.transforms.Transform`.
    """
    def __init__(
            self,
            scales: TypeTripletFloat,
            degrees: TypeTripletFloat,
            translation: TypeTripletFloat,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'minimum',
            image_interpolation: str = 'linear',
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(keys=keys)
        self.scales = self.parse_params(
            scales,
            None,
            'scales',
            make_ranges=False,
            min_constraint=0,
        )
        self.degrees = self.parse_params(
            degrees,
            None,
            'degrees',
            make_ranges=False,
        )
        self.translation = self.parse_params(
            translation,
            None,
            'translation',
            make_ranges=False,
        )
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.center = center
        self.use_image_center = center == 'image'
        self.default_pad_value = _parse_default_value(default_pad_value)
        self.image_interpolation = self.parse_interpolation(image_interpolation)
        # When True, the composed transform is inverted before resampling
        self.invert_transform = False
        self.args_names = (
            'scales',
            'degrees',
            'translation',
            'center',
            'default_pad_value',
            'image_interpolation',
        )

    @staticmethod
    def get_scaling_transform(
            scaling_params: Sequence[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scaling transform.

        Args:
            scaling_params: Scaling factor along each axis.
            center_lps: Optional center of the scaling, in LPS coordinates.
        """
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: Sequence[float],
            translation: Sequence[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build a SimpleITK Euler transform (rotation plus translation).

        Args:
            degrees: Rotation angle around each axis, in degrees.
            translation: Translation along each axis.
            center_lps: Optional center of rotation, in LPS coordinates.
        """
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        """Resample every selected image of ``subject`` in place.

        Non-intensity images use nearest-neighbor interpolation; intensity
        images use ``self.image_interpolation``.
        """
        subject.check_consistent_spatial_shape()
        for image in self.get_images(subject):
            if image[TYPE] != INTENSITY:
                interpolation = 'nearest'
            else:
                interpolation = self.image_interpolation

            # Fresh copies for each image so the 2D adjustments below never
            # leak into the parameters used for another image
            scaling_params = np.array(self.scales, dtype=float)
            rotation_params = np.array(self.degrees, dtype=float)
            translation_params = np.array(self.translation, dtype=float)

            if image.is_2d():
                # Keep the third axis unscaled and only allow rotation
                # around the last axis
                scaling_params[2] = 1
                rotation_params[:-1] = 0

            center = image.get_center(lps=True) if self.use_image_center else None

            # Each element of image[DATA] must be a 3D tensor (asserted in
            # apply_affine_transform); presumably one per channel
            transformed_tensors = [
                self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                for tensor in image[DATA]
            ]
            image[DATA] = torch.stack(transformed_tensors)
        return subject

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: Sequence[float],
            rotation_params: Sequence[float],
            translation_params: Sequence[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample a single 3D tensor with the given affine parameters.

        Args:
            tensor: 3D tensor to resample.
            affine: Affine matrix passed to ``nib_to_sitk`` with the tensor.
            scaling_params: Scaling factor along each axis.
            rotation_params: Rotation angle around each axis, in degrees.
            translation_params: Translation along each axis.
            interpolation: Interpolation passed to ``get_sitk_interpolator``.
            center_lps: Optional transform center, in LPS coordinates.

        Returns:
            The resampled tensor, computed in float32 on the input grid.
        """
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        floating = reference = image

        transform = self._get_composite_transform(
            scaling_params,
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )
        default_value = self._get_default_pad_value(tensor, image)

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor

    def _get_composite_transform(
            self,
            scaling_params: Sequence[float],
            rotation_params: Sequence[float],
            translation_params: Sequence[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ):
        """Compose scaling and rotation/translation into one sitk transform."""
        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )
        if get_major_sitk_version() == 1:
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        else:
            # SimpleITK 2 provides CompositeTransform. Later major versions
            # are assumed to keep this API, so `transform` is always bound
            # (previously it was undefined for versions other than 1 and 2).
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.invert_transform:
            transform = transform.GetInverse()
        return transform

    def _get_default_pad_value(self, tensor: torch.Tensor, image) -> float:
        """Return the fill value for voxels mapped outside the input image."""
        if self.default_pad_value == 'minimum':
            return tensor.min().item()
        if self.default_pad_value == 'mean':
            return get_borders_mean(image, filter_otsu=False)
        if self.default_pad_value == 'otsu':
            return get_borders_mean(image, filter_otsu=True)
        # Otherwise it was validated as a number
        return self.default_pad_value
357
358
359
# flake8: noqa: E201, E203, E243
360
def get_borders_mean(image, filter_otsu=True):
    """Return the mean value of the voxels on the six faces of ``image``.

    Args:
        image: SimpleITK image. The indexing below assumes its array has
            exactly three axes.
        filter_otsu: If ``True``, only border values strictly below an Otsu
            threshold computed on the border values are averaged. If no value
            lies below the threshold, the mean of all border values is
            returned instead.

    Returns:
        The mean of the selected border values.
    """
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    # Wrap the border values in a tiny image so SimpleITK's Otsu filter
    # can compute a threshold on them
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    # Test for emptiness with ``size``. ``values.any()`` is False when every
    # selected value is 0 (e.g. a zero background), which used to discard a
    # valid selection in favour of the mean of all border values.
    if values.size > 0:
        default_value = values.mean()
    else:
        default_value = borders_flat.mean()
    return default_value
385
386
def _parse_scales_isotropic(scales, isotropic):
    """Check that ``scales`` is compatible with isotropic scaling.

    Args:
        scales: Value of the ``scales`` argument of RandomAffine. It may be
            a single number or a sequence.
        isotropic: Whether the same scaling factor is used along all axes.

    Raises:
        ValueError: If ``isotropic`` is ``True`` and ``scales`` has 3 or 6
            values, i.e. it defines per-axis ranges.
    """
    params = to_tuple(scales)
    # Measure the normalized tuple, not the raw argument: a scalar such as
    # the default 0.1 has no len() and used to raise TypeError here.
    if isotropic and len(params) in (3, 6):
        message = (
            'If "isotropic" is True, the value for "scales" must have'
            f' length 1 or 2, but "{scales}" was passed'
        )
        raise ValueError(message)
394
395
def _parse_default_value(value: Union[str, float]) -> Union[str, float]:
396
    if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
397
        return value
398
    message = (
399
        'Value for default_pad_value must be "minimum", "otsu", "mean"'
400
        ' or a number'
401
    )
402
    raise ValueError(message)
403