Passed — Pull Request, master (#346), created by Fernando (05:26)

RandomAffine.to_range() — rating A

Complexity
    Conditions: 2

Size
    Total Lines: 6
    Code Lines: 5

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric   Value
cc       2
eloc     5
nop      2
dl       0
loc      6
rs       10
c        0
b        0
f        0
from numbers import Number
from typing import Tuple, Optional, List, Union
import torch
import numpy as np
import SimpleITK as sitk
from ....data.subject import Subject
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
from ....torchio import (
    INTENSITY,
    DATA,
    AFFINE,
    TYPE,
    TypeRangeFloat,
    TypeSextetFloat,
    TypeTripletFloat,
)
from ... import SpatialTransform
from .. import Interpolation, get_sitk_interpolator
from .. import RandomTransform


TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]

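# ``TypeOneToSixFloat`` (above) covers a single float, a pair, a triplet or a
# sextet; the ``RandomAffine`` docstring below describes how each form is
# interpreted.
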
class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=False,
        ...     default_pad_value='otsu',
        ...     image_interpolation='bspline',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "degrees=30 default_pad_value=minimum" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            p: float = 1,
            seed: Optional[int] = None,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, seed=seed, keys=keys)
        self.isotropic = isotropic
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.use_image_center = center == 'image'
        self.default_pad_value = self.parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)

    @staticmethod
    def to_range(n, around):
        if around is None:
            return 0, n
        else:
            return around - n, around + n

    def parse_params(self, params, around, name, **kwargs):
        params = to_tuple(params)
        length = len(params)
        if name == 'scales' and self.isotropic and length in (3, 6):
            message = (
                f'If "isotropic" is True, the value for "{name}" must have'
                f' length 1 or 2, but "{params}" was passed'
            )
            raise ValueError(message)
        if len(params) == 1 or len(params) == 2:  # d or (a, b)
            params *= 3  # (d, d, d) or (a, b, a, b, a, b)
        if len(params) == 3:  # (a, b, c)
            items = [self.to_range(n, around) for n in params]
            # (-a, a, -b, b, -c, c) or (1-a, 1+a, 1-b, 1+b, 1-c, 1+c)
            params = [n for prange in items for n in prange]
        assert len(params) == 6
        for param_range in zip(params[::2], params[1::2]):
            self.parse_range(param_range, name, **kwargs)
        return params

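    # Illustrative expansion examples (not in the original source), following
    # the parsing logic above with hypothetical argument values:
    #   self.parse_params(0.1, 1, 'scales')        -> 0.9, 1.1, 0.9, 1.1, 0.9, 1.1
    #   self.parse_params(10, 0, 'degrees')        -> -10, 10, -10, 10, -10, 10
    #   self.parse_params((0.9, 1.2), 1, 'scales') -> 0.9, 1.2, 0.9, 1.2, 0.9, 1.2
    # i.e. a scalar becomes a symmetric range around ``around`` for each of the
    # three axes, and a pair is reused as the range for all three axes.
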
    @staticmethod
    def parse_default_value(value: Union[str, float]) -> Union[str, float]:
        if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
            return value
        message = (
            'Value for default_pad_value must be "minimum", "otsu", "mean"'
            ' or a number'
        )
        raise ValueError(message)

    def get_params(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        scaling_params = self.sample_uniform_sextet(self.scales)
        if self.isotropic:
            scaling_params.fill_(scaling_params[0])
        rotation_params = self.sample_uniform_sextet(self.degrees)
        translation_params = self.sample_uniform_sextet(self.translation)
        return scaling_params, rotation_params, translation_params

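    # ``sample_uniform_sextet`` (inherited from ``RandomTransform``) presumably
    # draws one value per (min, max) pair, giving a tensor of three values per
    # parameter. When ``isotropic`` is True, the first sampled scale is copied
    # to all three axes via ``fill_``, so a single zoom factor is used.
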
    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

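    # Worked example (illustrative): scaling_params = [1.5, 1.5, 1.5] sets the
    # SimpleITK scale to 1 / 1.5 ≈ 0.667. Since the resampling transform maps
    # output coordinates back to input coordinates, this makes objects appear
    # 1.5 times larger in the resampled image.
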
    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            translation: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params()
        for image in self.get_images(subject):
            if image[TYPE] != INTENSITY:
                interpolation = Interpolation.NEAREST
            else:
                interpolation = self.interpolation

            if image.is_2d():
                # 2D image: fix the scale along the last axis and zero all but
                # the last rotation angle, so only in-plane rotation remains
                scaling_params[2] = 1
                rotation_params[:-1] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            transformed_tensors = []
            for tensor in image[DATA]:
                transformed_tensor = self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                transformed_tensors.append(transformed_tensor)
            image[DATA] = torch.stack(transformed_tensors)
        random_parameters_dict = {
            'scaling': scaling_params,
            'rotation': rotation_params,
            'translation': translation_params,
        }
        subject.add_transform(self, random_parameters_dict)
        return subject

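    # Note: the sampled scaling, rotation and translation values are attached
    # to the subject via ``subject.add_transform`` above, so the parameters
    # used for a particular sample can be inspected after augmentation.
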
    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            translation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        sitk_major_version = get_major_sitk_version()
        # Assumes the installed SimpleITK major version is 1 or 2; otherwise
        # `transform` would be left undefined below
        if sitk_major_version == 1:
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        elif sitk_major_version == 2:
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor

# flake8: noqa: E201, E203, E243
def get_borders_mean(image, filter_otsu=True):
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    if values.any():
        default_value = values.mean()
    else:
        default_value = borders_flat.mean()
    return default_value
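

# ----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module): a minimal check of
# ``get_borders_mean`` on a synthetic SimpleITK image, assuming only NumPy and
# SimpleITK. With ``filter_otsu=False`` the plain mean of all border voxels is
# returned; with ``filter_otsu=True`` only border voxels below the Otsu
# threshold contribute (falling back to the plain mean if none qualify).
#
#     volume = np.zeros((16, 16, 16), dtype=np.float32)
#     volume[0] = 100  # one bright face; the other borders stay dark
#     toy_image = sitk.GetImageFromArray(volume)
#     get_borders_mean(toy_image, filter_otsu=False)  # > 0: includes the bright face
#     get_borders_mean(toy_image, filter_otsu=True)   # ~0: bright face lies above the threshold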