Passed
Push — master ( a5fd0f...582603 )
by Fernando
01:13
created

RandomAffine.__init__()   A

Complexity

Conditions 2

Size

Total Lines 24
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 21
nop 9
dl 0
loc 24
rs 9.376
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
from numbers import Number
2
from typing import Tuple, Optional, List, Union
3
import torch
4
import numpy as np
5
import SimpleITK as sitk
6
from ....data.subject import Subject
7
from ....torchio import (
8
    LABEL,
9
    DATA,
10
    AFFINE,
11
    TYPE,
12
    TypeRangeFloat,
13
    TypeTripletFloat,
14
)
15
from .. import Interpolation, get_sitk_interpolator
16
from .. import RandomTransform
17
18
19
class RandomAffine(RandomTransform):
    r"""Apply a random scaling and rotation to the images of a sample.

    Args:
        scales: Tuple :math:`(a, b)` giving the range of the scaling
            magnitude. One factor is drawn per dimension,
            :math:`(s_1, s_2, s_3)`, with :math:`s_i \sim \mathcal{U}(a, b)`.
            For instance, ``scales=(0.5, 0.5)`` zooms out the image, so the
            objects inside look twice as small while the physical size and
            position of the image are preserved.
        degrees: Tuple :math:`(a, b)` giving the rotation range in degrees.
            One angle is drawn per axis,
            :math:`(\theta_1, \theta_2, \theta_3)`,
            with :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If a single value :math:`d` is given,
            :math:`\theta_i \sim \mathcal{U}(-d, d)`.
        isotropic: If ``True``, the same scaling factor is used along all
            dimensions, i.e. :math:`s_1 = s_2 = s_3`.
        center: If ``'image'``, rotations and scaling are performed around
            the image center. If ``'origin'``, they are performed around
            the origin in world coordinates.
        default_pad_value: Value used for the voxels near the borders that
            become undefined once the image is rotated.
            If ``'minimum'``, the fill value is the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            A number is used as the fill value directly.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.

    Example:
        >>> from torchio.transforms import RandomAffine, Interpolation
        >>> sample = images_dataset[0]  # instance of torchio.ImagesDataset
        >>> transform = RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=False,
        ...     default_pad_value='otsu',
        ...     image_interpolation='bspline',
        ... )
        >>> transformed = transform(sample)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "degrees=30 default_pad_value=minimum" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: Tuple[float, float] = (0.9, 1.1),
            degrees: TypeRangeFloat = 10,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            p: float = 1,
            seed: Optional[int] = None,
            ):
        """Store and validate the transform settings.

        Raises:
            ValueError: If ``center`` is neither ``'image'`` nor
                ``'origin'``, or if ``default_pad_value`` is not accepted
                by :py:meth:`parse_default_value`.
        """
        super().__init__(p=p, seed=seed)
        self.scales = scales
        self.degrees = self.parse_degrees(degrees)
        self.isotropic = isotropic
        valid_centers = ('image', 'origin')
        if center not in valid_centers:
            raise ValueError(
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
        # Only a flag is kept: True means "rotate/scale about image center"
        self.use_image_center = center == 'image'
        self.default_pad_value = self.parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)

    @staticmethod
    def parse_default_value(value: Union[str, float]) -> Union[str, float]:
        """Return ``value`` unchanged if it is a valid padding specification.

        A valid specification is any number or one of the strings
        ``'minimum'``, ``'otsu'`` or ``'mean'``.

        Raises:
            ValueError: If ``value`` is none of the above.
        """
        is_number = isinstance(value, Number)
        if not is_number and value not in ('minimum', 'otsu', 'mean'):
            raise ValueError(
                'Value for default_pad_value must be "minimum", "otsu", "mean"'
                ' or a number'
            )
        return value

    @staticmethod
    def get_params(
            scales: Tuple[float, float],
            degrees: Tuple[float, float],
            isotropic: bool,
            ) -> Tuple[np.ndarray, np.ndarray]:
        """Sample three scaling factors and three rotation angles.

        Both triplets are drawn uniformly from the given ranges. When
        ``isotropic`` is true, the first sampled scaling factor is copied
        to the other two dimensions.

        Returns:
            Tuple ``(scaling, rotation)`` of NumPy arrays with 3 elements.
        """
        # Scaling is sampled before rotation; keep this order so results
        # stay reproducible for a given random state
        scaling = torch.FloatTensor(3).uniform_(*scales)
        if isotropic:
            scaling.fill_(scaling[0])
        rotation = torch.FloatTensor(3).uniform_(*degrees)
        return scaling.numpy(), rotation.numpy()

    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scale transform.

        Args:
            scaling_params: Scaling factor for each dimension.
            center_lps: Center of the transform. If ``None``, the center of
                the SimpleITK transform is not modified.
        """
        # The factors are inverted so that they are more intuitive:
        # for example, 1.5 means the objects look 1.5 times larger
        inverted_factors = 1 / np.array(scaling_params)
        scale_transform = sitk.ScaleTransform(3)
        scale_transform.SetScale(inverted_factors)
        if center_lps is not None:
            scale_transform.SetCenter(center_lps)
        return scale_transform

    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build a SimpleITK Euler 3D transform from angles in degrees.

        Args:
            degrees: Rotation angle for each axis, in degrees.
            center_lps: Center of the transform. If ``None``, the center of
                the SimpleITK transform is not modified.
        """
        angles_in_radians = np.radians(degrees)
        euler_transform = sitk.Euler3DTransform()
        euler_transform.SetRotation(*angles_in_radians)
        if center_lps is not None:
            euler_transform.SetCenter(center_lps)
        return euler_transform

    def apply_transform(self, sample: Subject) -> dict:
        """Apply the same random affine transform to every image in a sample.

        The parameters are sampled once and reused for all the images.
        Images whose type is ``LABEL`` are always resampled with nearest
        neighbor interpolation. The sampled parameters are recorded in the
        sample through ``add_transform``.

        Returns:
            The input sample, whose images have been modified.
        """
        sample.check_consistent_shape()
        scaling_params, rotation_params = self.get_params(
            self.scales,
            self.degrees,
            self.isotropic,
        )
        for image in sample.get_images(intensity_only=False):
            is_label = image[TYPE] == LABEL
            interpolation = (
                Interpolation.NEAREST if is_label else self.interpolation
            )

            if image.is_2d():
                # For 2D images, neutralize the first scaling factor and
                # the last two rotation angles
                scaling_params[0] = 1
                rotation_params[-2:] = 0

            center = image.get_center(lps=True) if self.use_image_center else None

            image[DATA] = self.apply_affine_transform(
                image[DATA],
                image[AFFINE],
                scaling_params.tolist(),
                rotation_params.tolist(),
                interpolation,
                center_lps=center,
            )
        sample.add_transform(
            self,
            {
                'scaling': scaling_params,
                'rotation': rotation_params,
            },
        )
        return sample

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample a tensor using a scaling followed by a rotation.

        Args:
            tensor: Tensor with 4 dimensions whose first dimension has
                size 1. It is modified in place.
            affine: Affine matrix passed with the data to ``nib_to_sitk``.
            scaling_params: Scaling factor for each dimension.
            rotation_params: Rotation angle for each axis, in degrees.
            interpolation: Interpolation used for resampling.
            center_lps: Center of the scaling and the rotation. If ``None``,
                the centers of the SimpleITK transforms are not modified.

        Returns:
            The input tensor, whose first element holds the resampled data.
        """
        assert tensor.ndim == 4
        assert len(tensor) == 1

        # The same image acts as floating and as reference, so the output
        # is defined on the grid of the input
        image = self.nib_to_sitk(tensor[0], affine)

        components = (
            self.get_scaling_transform(scaling_params, center_lps=center_lps),
            self.get_rotation_transform(rotation_params, center_lps=center_lps),
        )
        composite = sitk.Transform(3, sitk.sitkComposite)
        for component in components:
            composite.AddTransform(component)

        pad_value = self.default_pad_value
        if pad_value == 'minimum':
            fill_value = tensor.min().item()
        elif pad_value in ('mean', 'otsu'):
            fill_value = get_borders_mean(
                image,
                filter_otsu=pad_value == 'otsu',
            )
        else:
            fill_value = pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(image)
        resampler.SetDefaultPixelValue(float(fill_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(composite)
        resampled = resampler.Execute(image)

        # Reverse the axes order: ITK to NumPy
        resampled_array = sitk.GetArrayFromImage(resampled).transpose()
        tensor[0] = torch.from_numpy(resampled_array)
        return tensor
223
224
225
def get_borders_mean(image, filter_otsu=True):
    """Compute the mean intensity of the six border faces of a 3D image.

    Args:
        image: SimpleITK image whose array view is indexed with three
            indices.
        filter_otsu: If ``True``, an Otsu threshold is computed from the
            border values and only the values strictly below it are
            averaged. If no value lies below the threshold, the mean of
            all the border values is returned instead.

    Returns:
        Mean of the selected border values.
    """
    # pylint: disable=bad-whitespace
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[ 0,  :,  :],
        array[-1,  :,  :],
        array[ :,  0,  :],
        array[ :, -1,  :],
        array[ :,  :,  0],
        array[ :,  :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()

    # The Otsu filter works on images, so wrap the flat values in one
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]

    # Test for emptiness with the size, not with ``any()``: ``any()`` is
    # also False when all the values under the threshold are zero (e.g. a
    # zero background), which would wrongly discard the Otsu filtering
    if values.size > 0:
        return values.mean()
    return borders_flat.mean()
250