Passed
Pull Request — master (#346)
by Fernando
05:34 queued 04:10

RandomAffine.to_range() (rated A)

Complexity
    Conditions      2

Size
    Total Lines     6
    Code Lines      5

Duplication
    Lines           0
    Ratio           0 %

Importance
    Changes         0

Metric    Value
cc        2
eloc      5
nop       2
dl        0
loc       6
rs        10
c         0
b         0
f         0
from numbers import Number
from typing import Tuple, Optional, List, Union
import torch
import numpy as np
import SimpleITK as sitk
from ....data.subject import Subject
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
from ....torchio import (
    INTENSITY,
    DATA,
    AFFINE,
    TYPE,
    TypeRangeFloat,
    TypeSextetFloat,
    TypeTripletFloat,
)
from ... import SpatialTransform
from .. import Interpolation, get_sitk_interpolator
from .. import RandomTransform


TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]


class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=False,
        ...     default_pad_value='otsu',
        ...     image_interpolation='bspline',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "degrees=30 default_pad_value=minimum" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            p: float = 1,
            seed: Optional[int] = None,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, seed=seed, keys=keys)
        self.isotropic = isotropic
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.use_image_center = center == 'image'
        self.default_pad_value = self.parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)

    @staticmethod
    def to_range(n, around):
        if around is None:
            return 0, n
        else:
            return around - n, around + n
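
    # Reviewer's note (illustrative, not in the original file): to_range turns
    # a half-width n into a symmetric interval around `around`, or into (0, n)
    # when no center is given. Hypothetical doctest:
    #
    #     >>> RandomAffine.to_range(0.1, 1)    # used for scales
    #     (0.9, 1.1)
    #     >>> RandomAffine.to_range(10, 0)     # used for degrees/translation
    #     (-10, 10)
    #     >>> RandomAffine.to_range(10, None)
    #     (0, 10)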

    def parse_params(self, params, around, name, **kwargs):
        params = to_tuple(params)
        length = len(params)
        if name == 'scales' and self.isotropic and length in (3, 6):
            message = (
                f'If "isotropic" is True, the value for "{name}" must have'
                f' length 1 or 2, but "{params}" was passed'
            )
            raise ValueError(message)
        if len(params) == 1 or len(params) == 2:  # d or (a, b)
            params *= 3  # (d, d, d) or (a, b, a, b, a, b)
        if len(params) == 3:  # (a, b, c)
            items = [self.to_range(n, around) for n in params]
            # (-a, a, -b, b, -c, c) or (1-a, 1+a, 1-b, 1+b, 1-c, 1+c)
            params = [n for prange in items for n in prange]
        if len(params) != 6:
            message = (
                f'If "{name}" is a sequence, it must have length 2, 3 or 6,'
                f' not {len(params)}'
            )
            raise ValueError(message)
        for param_range in zip(params[::2], params[1::2]):
            self.parse_range(param_range, name, **kwargs)
        return tuple(params)
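
    # Worked example added for review (hedged, not in the original file):
    # assuming to_tuple simply wraps scalars/sequences into a tuple,
    # parse_params always returns a sextet of (min, max) bounds:
    #
    #     self.parse_params(0.1, 1, 'scales')        -> (0.9, 1.1, 0.9, 1.1, 0.9, 1.1)
    #     self.parse_params((0.9, 1.2), 1, 'scales') -> (0.9, 1.2, 0.9, 1.2, 0.9, 1.2)
    #     self.parse_params(10, 0, 'degrees')        -> (-10, 10, -10, 10, -10, 10)
    #     self.parse_params((5, 10, 15), 0, 'translation')
    #         -> (-5, 5, -10, 10, -15, 15)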

    @staticmethod
    def parse_default_value(value: Union[str, float]) -> Union[str, float]:
        if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
            return value
        message = (
            'Value for default_pad_value must be "minimum", "otsu", "mean"'
            ' or a number'
        )
        raise ValueError(message)

    def get_params(
            self,
            scales,
            degrees,
            translation,
            isotropic,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        scaling_params = self.sample_uniform_sextet(scales)
        if isotropic:
            scaling_params.fill_(scaling_params[0])
        rotation_params = self.sample_uniform_sextet(degrees)
        translation_params = self.sample_uniform_sextet(translation)
        return scaling_params, rotation_params, translation_params
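
    # Reviewer's note (hedged): sample_uniform_sextet is inherited from
    # RandomTransform and is not shown in this file. Judging by how the
    # results are used below (indexing element 2 for 2D images, passing three
    # values to SimpleITK), it presumably draws one sample per (min, max)
    # pair, so a sextet such as (0.9, 1.1, 0.9, 1.1, 0.9, 1.1) yields a
    # 3-element tensor with one scaling factor per spatial axis.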

    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform
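
    # Reviewer's note (illustrative): during resampling the transform maps
    # output-space points back into the input image, so a user-facing scale
    # factor of 2 (objects appear twice as large) becomes an ITK scale of
    # 0.5, e.g. scaling_params=[2.0, 1.0, 1.0] -> SetScale([0.5, 1.0, 1.0]).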

    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            translation: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        for image in self.get_images(subject):
            if image[TYPE] != INTENSITY:
                interpolation = Interpolation.NEAREST
            else:
                interpolation = self.interpolation

            if image.is_2d():
                scaling_params[2] = 1
                rotation_params[:-1] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            transformed_tensors = []
            for tensor in image[DATA]:
                transformed_tensor = self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                transformed_tensors.append(transformed_tensor)
            image[DATA] = torch.stack(transformed_tensors)
        random_parameters_dict = {
            'scaling': scaling_params,
            'rotation': rotation_params,
            'translation': translation_params,
        }
        subject.add_transform(self, random_parameters_dict)
        return subject
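
    # Reviewer's note (descriptive): the same sampled parameters are reused
    # for every image in the subject, label maps (non-INTENSITY images) are
    # resampled with nearest-neighbour interpolation to keep labels integer,
    # and each channel of a multi-channel image is transformed separately
    # before being stacked back together.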

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            translation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        sitk_major_version = get_major_sitk_version()
        if sitk_major_version == 1:
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        elif sitk_major_version == 2:
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor
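
    # Reviewer's note (descriptive): sitk.GetArrayFromImage returns the voxels
    # with the axes reversed (z, y, x) relative to the SimpleITK image index
    # (x, y, z); the plain transpose() above reverses them back so the
    # returned tensor matches the original (x, y, z) layout.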

# flake8: noqa: E201, E203, E243
def get_borders_mean(image, filter_otsu=True):
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    if values.any():
        default_value = values.mean()
    else:
        default_value = borders_flat.mean()
    return default_value
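
To make the padding options easier to check during review, here is a minimal,
self-contained sketch (not part of the PR) that exercises get_borders_mean on a
toy volume. The array contents, the variable names (demo, demo_image) and the
printed values are illustrative assumptions only.

if __name__ == '__main__':
    # Toy volume: background of 10 with a bright cube touching one face,
    # so that exactly one border face contains bright voxels
    demo = np.full((32, 32, 32), 10, dtype=np.float32)
    demo[:16, 8:24, 8:24] = 100
    demo_image = sitk.GetImageFromArray(demo)

    # Plain border mean is pulled up by the bright face (13.75 here)
    print('mean:', get_borders_mean(demo_image, filter_otsu=False))
    # Otsu-filtered mean keeps only the dark border voxels (10.0 here)
    print('otsu:', get_borders_mean(demo_image, filter_otsu=True))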