Passed
Pull Request — master (#346)
by Fernando
01:21
created

RandomAffine.parse_params()   C

Complexity

Conditions 9

Size

Total Lines 24
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 19
nop 5
dl 0
loc 24
rs 6.6666
c 0
b 0
f 0
1
from numbers import Number
2
from typing import Tuple, Optional, List, Union
3
import torch
4
import numpy as np
5
import SimpleITK as sitk
6
from ....data.subject import Subject
7
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
8
from ....torchio import (
9
    INTENSITY,
10
    DATA,
11
    AFFINE,
12
    TYPE,
13
    TypeRangeFloat,
14
    TypeSextetFloat,
15
    TypeTripletFloat,
16
)
17
from ... import SpatialTransform
18
from .. import Interpolation, get_sitk_interpolator
19
from .. import RandomTransform
20
21
22
# Accepted type for the ``scales``, ``degrees`` and ``translation`` arguments
# of RandomAffine: a range, a triplet or a sextet of floats. These are
# normalized to a sextet by ``RandomAffine.parse_params``.
TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]
23
24
25
class RandomAffine(RandomTransform, SpatialTransform):
26
    r"""Random affine transformation.
27
28
    Args:
29
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
30
            scaling ranges.
31
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
32
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
33
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
34
            making the objects inside look twice as small while preserving
35
            the physical size and position of the image bounds.
36
            If only one value :math:`x` is provided,
37
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
38
            If two values :math:`(a, b)` are provided,
39
            then :math:`s_i \sim \mathcal{U}(a, b)`.
40
            If three values :math:`(x_1, x_2, x_3)` are provided,
41
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
42
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
43
            rotation ranges in degrees.
44
            Rotation angles around each axis are
45
            :math:`(\theta_1, \theta_2, \theta_3)`,
46
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
47
            If only one value :math:`x` is provided,
48
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
49
            If two values :math:`(a, b)` are provided,
50
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
51
            If three values :math:`(x_1, x_2, x_3)` are provided,
52
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
53
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
54
            translation ranges in mm.
55
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
56
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
57
            If only one value :math:`x` is provided,
58
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
59
            If two values :math:`(a, b)` are provided,
60
            then :math:`t_i \sim \mathcal{U}(a, b)`.
61
            If three values :math:`(x_1, x_2, x_3)` are provided,
62
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
63
        isotropic: If ``True``, the scaling factor along all dimensions is the
64
            same, i.e. :math:`s_1 = s_2 = s_3`. In this case, ``scales``
            must be a single value or a pair :math:`(a, b)`.
65
        center: If ``'image'``, rotations and scaling will be performed around
66
            the image center. If ``'origin'``, rotations and scaling will be
67
            performed around the origin in world coordinates.
68
        default_pad_value: As the image is rotated, some values near the
69
            borders will be undefined.
70
            If ``'minimum'``, the fill value will be the image minimum.
71
            If ``'mean'``, the fill value is the mean of the border values.
72
            If ``'otsu'``, the fill value is the mean of the values at the
73
            border that lie under an
74
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
75
            If it is a number, that value will be used.
76
        image_interpolation: See :ref:`Interpolation`.
77
        p: Probability that this transform will be applied.
78
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
79
        keys: See :py:class:`~torchio.transforms.Transform`.
80
81
    Example:
82
        >>> import torchio as tio
83
        >>> subject = tio.datasets.Colin27()
84
        >>> transform = tio.RandomAffine(
85
        ...     scales=(0.9, 1.2),
86
        ...     degrees=10,
87
        ...     isotropic=False,
88
        ...     default_pad_value='otsu',
89
        ...     image_interpolation='bspline',
90
        ... )
91
        >>> transformed = transform(subject)
92
93
    From the command line::
94
95
        $ torchio-transform t1.nii.gz RandomAffine --kwargs "degrees=30 default_pad_value=minimum" --seed 42 affine_min.nii.gz
96
97
    """
98
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            p: float = 1,
            seed: Optional[int] = None,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, seed=seed, keys=keys)
        # Stored first: parse_params reads self.isotropic to validate scales
        self.isotropic = isotropic
        # Normalize each range specification to a sextet of bounds
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        allowed_centers = 'image', 'origin'
        if center not in allowed_centers:
            raise ValueError(
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
        # Only two options exist, so a boolean flag is enough
        self.use_image_center = center == 'image'
        self.default_pad_value = self.parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)
125
126
    @staticmethod
    def to_range(n, around):
        """Expand a single value ``n`` into a ``(low, high)`` pair.

        If ``around`` is ``None`` the pair is ``(0, n)``; otherwise it is
        symmetric around ``around``: ``(around - n, around + n)``.
        """
        if around is not None:
            return around - n, around + n
        return 0, n
132
133
    def parse_params(self, params, around, name, **kwargs):
        """Normalize a user-given range specification into a sextet.

        Args:
            params: A single value ``d``, a pair ``(a, b)``, a triplet
                ``(x_1, x_2, x_3)`` or a sextet
                ``(a_1, b_1, a_2, b_2, a_3, b_3)``.
            around: Center used to expand single values into ranges (see
                ``to_range``). If ``None``, a value ``x`` becomes ``(0, x)``;
                otherwise it becomes ``(around - x, around + x)``.
            name: Name of the argument, used in error messages.
            **kwargs: Forwarded to ``parse_range`` for each of the three
                ``(low, high)`` pairs, e.g. ``min_constraint``.

        Returns:
            Tuple ``(a_1, b_1, a_2, b_2, a_3, b_3)``.

        Raises:
            ValueError: If ``params`` does not have length 1, 2, 3 or 6, or
                if ``name`` is ``'scales'``, ``self.isotropic`` is ``True``
                and 3 or 6 values are given. ``parse_range`` may raise too.
        """
        params = to_tuple(params)
        length = len(params)
        if name == 'scales' and self.isotropic and length in (3, 6):
            message = (
                f'If "isotropic" is True, the value for "{name}" must have'
                f' length 1 or 2, but "{params}" was passed'
            )
            raise ValueError(message)
        if length in (1, 2):  # d or (a, b)
            params *= 3  # (d, d, d) or (a, b, a, b, a, b)
        if len(params) == 3:  # (a, b, c)
            items = [self.to_range(n, around) for n in params]
            # (-a, a, -b, b, -c, c) or (1-a, 1+a, 1-b, 1+b, 1-c, 1+c)
            params = [n for prange in items for n in prange]
        if len(params) != 6:
            # Only reachable when the input length was not 1, 2, 3 or 6, so
            # ``length`` is the length the user actually passed
            message = (
                f'If "{name}" is a sequence, it must have length 1, 2, 3'
                f' or 6, not {length}'
            )
            raise ValueError(message)
        for param_range in zip(params[::2], params[1::2]):
            self.parse_range(param_range, name, **kwargs)
        return tuple(params)
157
158
    @staticmethod
    def parse_default_value(value: Union[str, float]) -> Union[str, float]:
        """Validate the ``default_pad_value`` argument.

        Returns ``value`` unchanged if it is a number or one of the keywords
        ``'minimum'``, ``'otsu'`` or ``'mean'``.

        Raises:
            ValueError: For any other value.
        """
        keywords = 'minimum', 'otsu', 'mean'
        if not (isinstance(value, Number) or value in keywords):
            message = (
                'Value for default_pad_value must be "minimum", "otsu", "mean"'
                ' or a number'
            )
            raise ValueError(message)
        return value
167
168
    def get_params(
            self,
            scales,
            degrees,
            translation,
            isotropic,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample scaling, rotation and translation parameters.

        Each range argument is a sextet of bounds handed to
        ``sample_uniform_sextet``. Sampling happens in the order scaling,
        rotation, translation. If ``isotropic`` is ``True``, all scaling
        factors are overwritten in place with the first sampled one.
        """
        sample = self.sample_uniform_sextet
        scaling = sample(scales)
        if isotropic:
            scaling.fill_(scaling[0])
        rotation = sample(degrees)
        shift = sample(translation)
        return scaling, rotation, shift
181
182
    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scaling transform.

        The factors are inverted before being set so that they read
        intuitively: a factor of 1.5 makes the objects look 1.5 times larger.
        If ``center_lps`` is given, it is used as the scaling center.
        """
        inverse_factors = 1 / np.array(scaling_params)
        scaling = sitk.ScaleTransform(3)
        scaling.SetScale(inverse_factors)
        if center_lps is not None:
            scaling.SetCenter(center_lps)
        return scaling
195
196
    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            translation: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build a SimpleITK Euler transform (rotation plus translation).

        Args:
            degrees: Rotation angles around the three axes, in degrees.
            translation: Translation along the three axes.
            center_lps: Optional rotation center; if ``None`` it is not set.
        """
        angles_rad = np.radians(degrees)
        rigid = sitk.Euler3DTransform()
        rigid.SetRotation(*angles_rad)
        rigid.SetTranslation(translation)
        if center_lps is not None:
            rigid.SetCenter(center_lps)
        return rigid
209
210
    def apply_transform(self, subject: Subject) -> Subject:
        """Apply one random affine transformation to the images of a subject.

        The parameters are sampled once, so every image returned by
        ``get_images`` receives the same transformation. Images whose type is
        not ``INTENSITY`` are resampled with nearest-neighbor interpolation.
        The sampled parameters are recorded with ``subject.add_transform``.
        """
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        for image in self.get_images(subject):
            if image[TYPE] != INTENSITY:
                interpolation = Interpolation.NEAREST
            else:
                interpolation = self.interpolation

            if image.is_2d():
                # Keep the transform in the plane of the image: no scaling
                # along the singleton third axis, rotation only around it,
                # and no translation along it. Translating along that axis
                # would move the single slice out of the sampling grid and
                # fill the whole output with the padding value.
                scaling_params[2] = 1
                rotation_params[:-1] = 0
                translation_params[2] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            # Convert once per image instead of once per channel
            scaling = scaling_params.tolist()
            rotation = rotation_params.tolist()
            translation = translation_params.tolist()
            transformed_tensors = [
                self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling,
                    rotation,
                    translation,
                    interpolation,
                    center_lps=center,
                )
                for tensor in image[DATA]
            ]
            image[DATA] = torch.stack(transformed_tensors)
        random_parameters_dict = {
            'scaling': scaling_params,
            'rotation': rotation_params,
            'translation': translation_params,
        }
        subject.add_transform(self, random_parameters_dict)
        return subject
253
254
    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            translation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample one 3D tensor with a scaling and an Euler transform.

        Args:
            tensor: 3D tensor, e.g. one channel of an image.
            affine: Affine matrix of the image, passed to ``nib_to_sitk``.
            scaling_params: Three scaling factors.
            rotation_params: Three rotation angles, in degrees.
            translation_params: Three translation values.
            interpolation: Interpolation used by the resampler.
            center_lps: Optional center for scaling and rotation.

        Returns:
            32-bit floating point tensor resampled on the grid of the input.
        """
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        if get_major_sitk_version() == 1:
            # SimpleITK 1.x: generic Transform in composite mode
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        else:
            # SimpleITK 2.x and later provide CompositeTransform. Using
            # ``else`` rather than ``== 2`` guarantees ``transform`` is always
            # bound, e.g. for future major versions.
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            # A number, already validated by parse_default_value
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor
309
310
# flake8: noqa: E201, E203, E243
311
def get_borders_mean(image, filter_otsu=True):
    """Return the mean of the values on the six faces of ``image``.

    Args:
        image: SimpleITK image. Its array is indexed with three indices, so a
            3D image is expected.
        filter_otsu: If ``True``, only the border values below the Otsu
            threshold of the border values are averaged. If no border value
            lies below that threshold, the mean of all border values is
            returned instead.
    """
    array = sitk.GetArrayViewFromImage(image)
    faces = (
        array[0, :, :],
        array[-1, :, :],
        array[:, 0, :],
        array[:, -1, :],
        array[:, :, 0],
        array[:, :, -1],
    )
    borders_flat = np.hstack([face.ravel() for face in faces])
    if not filter_otsu:
        return borders_flat.mean()
    # Otsu filter needs an image, so wrap the flat values in a 1 x 1 x N one
    borders_image = sitk.GetImageFromArray(borders_flat.reshape(1, 1, -1))
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    # Test for emptiness, not truthiness: ``values.any()`` is False when all
    # values below the threshold are 0 (e.g. a zero background), which would
    # wrongly fall back to the mean of all the border values.
    if values.size > 0:
        return values.mean()
    return borders_flat.mean()
336