Passed
Pull Request — master (#346)
by Fernando
01:46
created

RandomAffine.__init__()   A

Complexity

Conditions 2

Size

Total Lines 28
Code Lines 25

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 25
nop 11
dl 0
loc 28
rs 9.28
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as soon as the method needs more or different data.

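There are several approaches to avoid long parameter lists, for example introducing a parameter object that groups related values, passing a whole object instead of several of its fields, or replacing a parameter with a method call that computes it.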

1
from numbers import Number
2
from typing import Tuple, Optional, List, Union
3
import torch
4
import numpy as np
5
import SimpleITK as sitk
6
from ....data.subject import Subject
7
from ....utils import nib_to_sitk, get_major_sitk_version, to_tuple
8
from ....torchio import (
9
    INTENSITY,
10
    DATA,
11
    AFFINE,
12
    TYPE,
13
    TypeRangeFloat,
14
    TypeSextetFloat,
15
    TypeTripletFloat,
16
)
17
from ... import SpatialTransform
18
from .. import Interpolation, get_sitk_interpolator
19
from .. import RandomTransform
20
21
22
# Type accepted by the ``scales``, ``degrees`` and ``translation`` arguments
# of RandomAffine: a union of the range, triplet and sextet float aliases
# imported from torchio (per the RandomAffine docstring, 1, 2, 3 or 6 values).
TypeOneToSixFloat = Union[
    TypeRangeFloat,
    TypeTripletFloat,
    TypeSextetFloat,
]
23
24
25
class RandomAffine(RandomTransform, SpatialTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            scaling ranges.
            The scaling values along each dimension are :math:`(s_1, s_2, s_3)`,
            where :math:`s_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`s_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`s_i \sim \mathcal{U}(1 - x, 1 + x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`s_i \sim \mathcal{U}(1 - x_i, 1 + x_i)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image bounds.
        degrees: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            rotation ranges in degrees.
            Rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`\theta_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`\theta_i \sim \mathcal{U}(-x_i, x_i)`.
        translation: Tuple :math:`(a_1, b_1, a_2, b_2, a_3, b_3)` defining the
            translation ranges in mm.
            Translation along each axis is :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a_i, b_i)`.
            If two values :math:`(a, b)` are provided,
            then :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`x` is provided,
            then :math:`t_i \sim \mathcal{U}(-x, x)`.
            If three values :math:`(x_1, x_2, x_3)` are provided,
            then :math:`t_i \sim \mathcal{U}(-x_i, x_i)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`. In that case, ``scales``
            must contain one or two values.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If it is a number, that value will be used.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If ``center`` is not ``'image'`` or ``'origin'``, if
            ``default_pad_value`` is not a number or one of the supported
            strings, or if ``isotropic`` is ``True`` and ``scales`` has
            three or six values.

    Example:
        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=False,
        ...     default_pad_value='otsu',
        ...     image_interpolation='bspline',
        ... )
        >>> transformed = transform(subject)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "degrees=30 default_pad_value=minimum" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: TypeOneToSixFloat = 0.1,
            degrees: TypeOneToSixFloat = 10,
            translation: TypeOneToSixFloat = 0,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            p: float = 1,
            seed: Optional[int] = None,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, seed=seed, keys=keys)
        self.isotropic = isotropic
        # Validate before parse_params expands scales to six values.
        self.parse_scales_isotropic(scales, isotropic)
        self.scales = self.parse_params(scales, 1, 'scales', min_constraint=0)
        self.degrees = self.parse_params(degrees, 0, 'degrees')
        self.translation = self.parse_params(translation, 0, 'translation')
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.use_image_center = center == 'image'
        self.default_pad_value = self.parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)

    @staticmethod
    def parse_scales_isotropic(scales, isotropic):
        """Check that ``scales`` is compatible with isotropic scaling.

        Args:
            scales: Value passed by the user as ``scales``; a single number
                or a sequence.
            isotropic: Whether the same factor is used along all axes.

        Raises:
            ValueError: If ``isotropic`` is ``True`` and ``scales`` has three
                or six values. Per-axis ranges make no sense in that case,
                because only the first sampled factor is used
                (see ``get_params``).
        """
        # Normalize first so that scalar inputs (e.g. the default 0.1) have a
        # length; calling len() on the raw value would raise TypeError.
        params = to_tuple(scales)
        if isotropic and len(params) in (3, 6):
            message = (
                'If "isotropic" is True, the value for "scales" must have'
                f' length 1 or 2, but "{scales}" was passed'
            )
            raise ValueError(message)

    @staticmethod
    def parse_default_value(value: Union[str, float]) -> Union[str, float]:
        """Validate ``default_pad_value`` and return it unchanged.

        Raises:
            ValueError: If ``value`` is neither a number nor one of
                ``'minimum'``, ``'otsu'`` or ``'mean'``.
        """
        if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
            return value
        message = (
            'Value for default_pad_value must be "minimum", "otsu", "mean"'
            ' or a number'
        )
        raise ValueError(message)

    def get_params(
            self,
            scales: TypeSextetFloat,
            degrees: TypeSextetFloat,
            translation: TypeSextetFloat,
            isotropic: bool,
            ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Sample scaling, rotation and translation parameters.

        Each argument is a sextet of (min, max) pairs, one pair per axis.
        If ``isotropic`` is ``True``, every scaling value is overwritten with
        the first sampled one.

        Returns:
            Tuple ``(scaling_params, rotation_params, translation_params)``
            as returned by ``sample_uniform_sextet``.
        """
        scaling_params = self.sample_uniform_sextet(scales)
        if isotropic:
            scaling_params.fill_(scaling_params[0])
        rotation_params = self.sample_uniform_sextet(degrees)
        translation_params = self.sample_uniform_sextet(translation)
        return scaling_params, rotation_params, translation_params

    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scaling transform.

        Args:
            scaling_params: Three scaling factors, one per axis.
            center_lps: Optional center of the scaling, in LPS coordinates.
        """
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            translation: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build a SimpleITK Euler transform (rotation plus translation).

        Args:
            degrees: Three rotation angles in degrees, converted to radians.
            translation: Three translation values.
            center_lps: Optional center of rotation, in LPS coordinates.
        """
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        transform.SetTranslation(translation)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, subject: Subject) -> Subject:
        """Apply one randomly sampled affine transform to all images.

        The same sampled parameters are used for every image in the subject.
        Non-intensity images are resampled with nearest-neighbor
        interpolation. The sampled parameters are recorded with
        ``subject.add_transform``.
        """
        subject.check_consistent_spatial_shape()
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        for image in self.get_images(subject):
            if image[TYPE] != INTENSITY:
                interpolation = Interpolation.NEAREST
            else:
                interpolation = self.interpolation

            if image.is_2d():
                # Keep the third axis untouched: no scaling along it and only
                # rotation around it. This modifies the shared parameters in
                # place, so the recorded parameters reflect what was applied.
                scaling_params[2] = 1
                rotation_params[:-1] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            transformed_tensors = []
            for tensor in image[DATA]:
                transformed_tensor = self.apply_affine_transform(
                    tensor,
                    image[AFFINE],
                    scaling_params.tolist(),
                    rotation_params.tolist(),
                    translation_params.tolist(),
                    interpolation,
                    center_lps=center,
                )
                transformed_tensors.append(transformed_tensor)
            image[DATA] = torch.stack(transformed_tensors)
        random_parameters_dict = {
            'scaling': scaling_params,
            'rotation': rotation_params,
            'translation': translation_params,
        }
        subject.add_transform(self, random_parameters_dict)
        return subject

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            translation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample one 3D channel with the given affine parameters.

        Args:
            tensor: Single-channel 3D tensor.
            affine: Affine matrix of the image the tensor belongs to.
            scaling_params: Three scaling factors.
            rotation_params: Three rotation angles in degrees.
            translation_params: Three translation values.
            interpolation: Interpolation used by the resampler.
            center_lps: Optional center for scaling and rotation.

        Returns:
            The resampled tensor, with ``float32`` values.
        """
        assert tensor.ndim == 3

        image = nib_to_sitk(tensor[np.newaxis], affine, force_3d=True)
        # Resample onto the same grid as the input image.
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            translation_params,
            center_lps=center_lps,
        )

        sitk_major_version = get_major_sitk_version()
        if sitk_major_version < 2:
            # SimpleITK 1.x: build the composite via the generic Transform API.
            transform = sitk.Transform(3, sitk.sitkComposite)
            transform.AddTransform(scaling_transform)
            transform.AddTransform(rotation_transform)
        else:
            # SimpleITK 2 and later provide CompositeTransform. Using `else`
            # guarantees that `transform` is always bound.
            transforms = [scaling_transform, rotation_transform]
            transform = sitk.CompositeTransform(transforms)

        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor = torch.from_numpy(np_array)
        return tensor
288
289
# flake8: noqa: E201, E203, E243
290
def get_borders_mean(image, filter_otsu=True):
    """Return a fill value computed from the six boundary faces of ``image``.

    Args:
        image: SimpleITK image whose array view is 3D (the face indexing
            below assumes exactly three axes).
        filter_otsu: If ``False``, return the mean of all border values.
            If ``True``, return the mean of the border values lying strictly
            below their Otsu threshold. If no value lies below it, fall back
            to the mean of all border values.

    Returns:
        The fill value as a NumPy scalar.
    """
    array = sitk.GetArrayViewFromImage(image)
    borders_flat = _get_borders_flat(array)
    if not filter_otsu:
        return borders_flat.mean()
    # Otsu's filter operates on images, so wrap the 1D values in a thin image.
    borders_image = sitk.GetImageFromArray(borders_flat.reshape(1, 1, -1))
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    return _mean_below_threshold(borders_flat, threshold)


def _get_borders_flat(array):
    """Concatenate the six faces of a 3D array into one flat array.

    Edge and corner voxels belong to several faces and are repeated.
    """
    faces = (
        array[0, :, :],
        array[-1, :, :],
        array[:, 0, :],
        array[:, -1, :],
        array[:, :, 0],
        array[:, :, -1],
    )
    return np.hstack([face.ravel() for face in faces])


def _mean_below_threshold(values, threshold):
    """Mean of ``values`` strictly below ``threshold``, else mean of all.

    Emptiness is checked with ``.size``. ``ndarray.any()`` would report
    "empty" whenever all selected values are zero (e.g. a zero background),
    wrongly falling back to the overall mean.
    """
    below = values[values < threshold]
    if below.size == 0:
        return values.mean()
    return below.mean()
315