Passed
Pull Request — master (#346)
by Fernando
01:46
created

RandomAffine.parse_scales_isotropic()   A

Complexity

Conditions 3

Size

Total Lines 9
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 7
nop 2
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
from numbers import Number
2
from typing import Tuple, Optional, List, Union
3
import torch
4
import numpy as np
5
import SimpleITK as sitk
6
from ....data.subject import Subject
7
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
8
from ....torchio import (
9
    INTENSITY,
10
    DATA,
11
    AFFINE,
12
    TYPE,
13
    TypeRangeFloat,
14
    TypeSextetFloat,
15
    TypeTripletFloat,
16
)
17
from ... import SpatialTransform
18
from .. import Interpolation, get_sitk_interpolator
19
from .. import RandomTransform
20
21
22
# Accepted forms for the ``scales``, ``degrees`` and ``translation`` arguments
# of RandomAffine: one or two values (TypeRangeFloat), three values
# (TypeTripletFloat) or six values (TypeSextetFloat).
TypeOneToSixFloat = Union[TypeRangeFloat, TypeTripletFloat, TypeSextetFloat]
23
24
25
class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`. In that case, ``scales`` must
            have one or two values.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=False,
        ...     default_pad_value='otsu',
        ...     image_interpolation='bspline',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "degrees=30 default_pad_value=minimum" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            p: float = 1,
            seed: Optional[int] = None,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, seed=seed, keys=keys)
        self.isotropic = isotropic
        # Validate the scales/isotropic combination before the generic
        # parsing so the user gets a specific error message
        self.parse_scales_isotropic(scales, isotropic)
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.use_image_center = center == 'image'
        self.default_pad_value = self.parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)

    @staticmethod
    def parse_scales_isotropic(scales, isotropic):
        """Check that per-axis scaling ranges are not combined with isotropy.

        Args:
            scales: Value of the ``scales`` constructor argument, either a
                scalar or a sequence of values.
            isotropic: Value of the ``isotropic`` constructor argument.

        Raises:
            ValueError: If ``isotropic`` is ``True`` and ``scales`` has three
                or six values, which would define different ranges per axis.
        """
        params = to_tuple(scales)
        # Measure the normalized tuple, not the raw argument: a scalar such
        # as the default 0.1 has no len()
        if isotropic and len(params) in (3, 6):
            message = (
                'If "isotropic" is True, the value for "scales" must have'
                f' length 1 or 2, but "{scales}" was passed'
            )
            raise ValueError(message)

    @staticmethod
    def parse_default_value(value: Union[str, float]) -> Union[str, float]:
        """Validate the ``default_pad_value`` constructor argument.

        Args:
            value: ``'minimum'``, ``'otsu'``, ``'mean'`` or a number.

        Returns:
            The value, unchanged.

        Raises:
            ValueError: If the value is neither a number nor one of the
                supported strings.
        """
        if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
            return value
        message = (
            'Value for default_pad_value must be "minimum", "otsu", "mean"'
            ' or a number'
        )
        raise ValueError(message)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample scaling, rotation and translation parameters.

        Each sextet is sampled with ``sample_uniform_sextet``. If
        ``isotropic`` is ``True``, every scaling factor is overwritten with
        the first sampled one.

        Returns:
            Tuple ``(scaling_params, rotation_params, translation_params)``
            of tensors (presumably one value per axis).
        """
        scaling_params = self.sample_uniform_sextet(scales)
        if isotropic:
            scaling_params.fill_(scaling_params[0])
        rotation_params = self.sample_uniform_sextet(degrees)
        translation_params = self.sample_uniform_sextet(translation)
        return scaling_params, rotation_params, translation_params

    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scaling transform.

        Args:
            scaling_params: Scaling factor for each of the three axes.
            center_lps: Optional center of the transform, in LPS coordinates.

        Returns:
            A ``sitk.ScaleTransform`` whose scales are the reciprocals of
            ``scaling_params``.
        """
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            translation: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build a SimpleITK rigid transform from Euler angles and a shift.

        Args:
            degrees: Rotation angles, in degrees, for the three axes.
            translation: Translation along each of the three axes.
            center_lps: Optional center of rotation, in LPS coordinates.

        Returns:
            A ``sitk.Euler3DTransform``.
        """
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        """Apply one randomly sampled affine transform to all images.

        The same parameters are used for every image in the subject.
        Non-intensity images are resampled with nearest-neighbor
        interpolation. The sampled parameters are recorded with
        ``subject.add_transform``.
        """
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        for image in self.get_images(subject):
            if image[TYPE] != INTENSITY:
                interpolation = Interpolation.NEAREST
            else:
                interpolation = self.interpolation

            if image.is_2d():
                # Keep the third axis unscaled and allow only the rotation
                # around it (the last rotation parameter)
                scaling_params[2] = 1
                rotation_params[:-1] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            # Each channel is resampled independently
            transformed_tensors = []
            for tensor in image[DATA]:
                transformed_tensor = self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                transformed_tensors.append(transformed_tensor)
            image[DATA] = torch.stack(transformed_tensors)
        random_parameters_dict = {
            'scaling': scaling_params,
            'rotation': rotation_params,
            'translation': translation_params,
        }
        subject.add_transform(self, random_parameters_dict)
        return subject

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            translation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample a single 3D channel with a scaling + rigid transform.

        Args:
            tensor: 3D tensor (one channel).
            affine: Affine matrix of the image the channel belongs to.
            scaling_params: Scaling factor for each axis.
            rotation_params: Rotation angle, in degrees, for each axis.
            translation_params: Translation along each axis.
            interpolation: Interpolation used by the resampler.
            center_lps: Optional center of the transforms, in LPS coordinates.

        Returns:
            The resampled channel as a float32 tensor.
        """
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        if get_major_sitk_version() == 1:
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        else:
            # SimpleITK 2 and later provide sitk.CompositeTransform. Using
            # ``else`` rather than ``== 2`` guarantees that ``transform`` is
            # always bound, even for newer major versions.
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor
288
289
# flake8: noqa: E201, E203, E243
def get_borders_mean(image, filter_otsu=True):
    """Return the mean value of the voxels on the six faces of an image.

    Args:
        image: SimpleITK image, indexed here as a 3D array.
        filter_otsu: If ``True``, only border values strictly below an Otsu
            threshold (computed on the border values themselves) are
            averaged. If no value lies below the threshold, all border values
            are averaged instead. If ``False``, all border values are
            averaged.

    Returns:
        The mean of the selected border values.
    """
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()
    # The Otsu filter works on images, so wrap the 1D border values in one
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]
    # Test for emptiness, not truthiness: ``values.any()`` is False when all
    # the below-threshold values are zero (e.g. a zero background), which
    # would wrongly fall back to the mean of all border values
    if values.size > 0:
        default_value = values.mean()
    else:
        default_value = borders_flat.mean()
    return default_value
315