Passed
Push — master ( a5fd0f...582603 )
by Fernando
01:13
created

RandomAffine.apply_transform()   B

Complexity

Conditions 5

Size

Total Lines 33
Code Lines 26

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 26
nop 2
dl 0
loc 33
rs 8.7893
c 0
b 0
f 0
1
from numbers import Number
2
from typing import Tuple, Optional, List, Union
3
import torch
4
import numpy as np
5
import SimpleITK as sitk
6
from ....data.subject import Subject
7
from ....torchio import (
8
    LABEL,
9
    DATA,
10
    AFFINE,
11
    TYPE,
12
    TypeRangeFloat,
13
    TypeTripletFloat,
14
)
15
from .. import Interpolation, get_sitk_interpolator
16
from .. import RandomTransform
17
18
19
class RandomAffine(RandomTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a, b)` defining the scaling
            magnitude. The scaling values along each dimension are
            :math:`(s_1, s_2, s_3)`, where :math:`s_i \sim \mathcal{U}(a, b)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image.
            Both values must be strictly positive.
        degrees: Tuple :math:`(a, b)` defining the rotation range in degrees.
            The rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`\theta_i \sim \mathcal{U}(-d, d)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`.
        center: If ``'image'``, rotations and scaling will be performed around
            the image center. If ``'origin'``, rotations and scaling will be
            performed around the origin in world coordinates.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'mean'``, the fill value is the mean of the border values.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            A number is used directly as the fill value.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.

    Raises:
        ValueError: If any value in ``scales`` is not positive, if ``center``
            is neither ``'image'`` nor ``'origin'``, or if
            ``default_pad_value`` is not a number or one of the accepted
            strings.

    Example:
        >>> from torchio.transforms import RandomAffine, Interpolation
        >>> sample = images_dataset[0]  # instance of torchio.ImagesDataset
        >>> transform = RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=10,
        ...     isotropic=False,
        ...     default_pad_value='otsu',
        ...     image_interpolation='bspline',
        ... )
        >>> transformed = transform(sample)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine --kwargs "degrees=30 default_pad_value=minimum" --seed 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: Tuple[float, float] = (0.9, 1.1),
            degrees: TypeRangeFloat = 10,
            isotropic: bool = False,
            center: str = 'image',
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: str = 'linear',
            p: float = 1,
            seed: Optional[int] = None,
            ):
        super().__init__(p=p, seed=seed)
        # The sampled scales are inverted (1 / scale) before being handed to
        # SimpleITK, so zero or negative values cannot produce a meaningful
        # transform. Fail early instead of resampling with an invalid scale.
        if any(scale <= 0 for scale in scales):
            message = (
                'Scaling values must be positive,'
                f' not "{scales}"'
            )
            raise ValueError(message)
        self.scales = scales
        # parse_degrees and parse_interpolation are not defined in this class;
        # they are inherited helpers.
        self.degrees = self.parse_degrees(degrees)
        self.isotropic = isotropic
        if center not in ('image', 'origin'):
            message = (
                'Center argument must be "image" or "origin",'
                f' not "{center}"'
            )
            raise ValueError(message)
        self.use_image_center = center == 'image'
        self.default_pad_value = self.parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)

    @staticmethod
    def parse_default_value(value: Union[str, float]) -> Union[str, float]:
        """Validate the ``default_pad_value`` argument and return it unchanged.

        Args:
            value: A number, or one of ``'minimum'``, ``'otsu'``, ``'mean'``.

        Returns:
            The same ``value`` that was passed in.

        Raises:
            ValueError: If ``value`` is neither a number nor one of the
                accepted strings.
        """
        if isinstance(value, Number) or value in ('minimum', 'otsu', 'mean'):
            return value
        message = (
            'Value for default_pad_value must be "minimum", "otsu", "mean"'
            ' or a number'
        )
        raise ValueError(message)

    @staticmethod
    def get_params(
            scales: Tuple[float, float],
            degrees: Tuple[float, float],
            isotropic: bool,
            ) -> Tuple[np.ndarray, np.ndarray]:
        """Sample the random scaling factors and rotation angles.

        Args:
            scales: Bounds :math:`(a, b)` of the uniform distribution the
                three scaling factors are drawn from.
            degrees: Bounds :math:`(a, b)` of the uniform distribution the
                three rotation angles are drawn from.
            isotropic: If ``True``, the first sampled scaling factor is used
                for all three dimensions.

        Returns:
            Tuple ``(scaling_params, rotation_params)`` of NumPy arrays with
            three elements each.
        """
        scaling_params = torch.FloatTensor(3).uniform_(*scales)
        if isotropic:
            scaling_params.fill_(scaling_params[0])
        rotation_params = torch.FloatTensor(3).uniform_(*degrees)
        return scaling_params.numpy(), rotation_params.numpy()

    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D SimpleITK scale transform.

        Args:
            scaling_params: One scaling factor per dimension.
            center_lps: Center of the transform. If ``None``, the center of
                the transform is left at its default.

        Returns:
            The configured :class:`sitk.ScaleTransform`.
        """
        # scaling_params are inverted so that they are more intuitive
        # For example, 1.5 means the objects look 1.5 times larger
        transform = sitk.ScaleTransform(3)
        scaling_params = 1 / np.array(scaling_params)
        transform.SetScale(scaling_params)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> sitk.Euler3DTransform:
        """Build a 3D SimpleITK Euler rotation transform.

        Args:
            degrees: Three rotation angles in degrees. They are converted to
                radians and passed, in order, to ``SetRotation``.
            center_lps: Center of the transform. If ``None``, the center of
                the transform is left at its default.

        Returns:
            The configured :class:`sitk.Euler3DTransform`.
        """
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        if center_lps is not None:
            transform.SetCenter(center_lps)
        return transform

    def apply_transform(self, sample: Subject) -> dict:
        """Apply one random affine transform to every image in ``sample``.

        A single set of scaling and rotation parameters is sampled and used
        for all images. Images whose type is ``LABEL`` are resampled with
        nearest-neighbor interpolation; all others use the interpolation
        configured at construction time.

        Args:
            sample: Subject whose images will be transformed in place.

        Returns:
            The same ``sample`` instance, with its image data replaced and
            the sampled parameters recorded through ``add_transform``.
        """
        sample.check_consistent_shape()
        scaling_params, rotation_params = self.get_params(
            self.scales, self.degrees, self.isotropic)
        for image in sample.get_images(intensity_only=False):
            if image[TYPE] == LABEL:
                interpolation = Interpolation.NEAREST
            else:
                interpolation = self.interpolation

            if image.is_2d():
                # Neutralize the first scaling factor and the last two
                # rotation angles for 2D images. The arrays are modified in
                # place, so the values recorded below reflect this change.
                scaling_params[0] = 1
                rotation_params[-2:] = 0

            if self.use_image_center:
                center = image.get_center(lps=True)
            else:
                center = None

            image[DATA] = self.apply_affine_transform(
                image[DATA],
                image[AFFINE],
                scaling_params.tolist(),
                rotation_params.tolist(),
                interpolation,
                center_lps=center,
            )
        random_parameters_dict = {
            'scaling': scaling_params,
            'rotation': rotation_params,
        }
        sample.add_transform(self, random_parameters_dict)
        return sample

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            interpolation: Interpolation,
            center_lps: Optional[TypeTripletFloat] = None,
            ) -> torch.Tensor:
        """Resample ``tensor`` after scaling and rotating it.

        Args:
            tensor: 4D tensor whose first dimension has size 1.
            affine: Affine matrix passed to ``nib_to_sitk`` together with
                ``tensor[0]`` to build the SimpleITK image.
            scaling_params: One scaling factor per dimension.
            rotation_params: Three rotation angles in degrees.
            interpolation: Interpolation used by the resampler.
            center_lps: Center of the scaling and the rotation. If ``None``,
                the centers of the transforms are left at their defaults.

        Returns:
            The input ``tensor``, whose first element has been overwritten
            in place with the resampled data.
        """
        assert tensor.ndim == 4, 'The tensor must have 4 dimensions'
        assert len(tensor) == 1, 'The first dimension must have size 1'

        # nib_to_sitk is not defined in this class; it is an inherited helper
        image = self.nib_to_sitk(tensor[0], affine)
        # The image is resampled onto its own grid
        floating = reference = image

        scaling_transform = self.get_scaling_transform(
            scaling_params,
            center_lps=center_lps,
        )
        rotation_transform = self.get_rotation_transform(
            rotation_params,
            center_lps=center_lps,
        )
        transform = sitk.Transform(3, sitk.sitkComposite)
        transform.AddTransform(scaling_transform)
        transform.AddTransform(rotation_transform)

        # Value used for voxels that fall outside the original image
        if self.default_pad_value == 'minimum':
            default_value = tensor.min().item()
        elif self.default_pad_value == 'mean':
            default_value = get_borders_mean(image, filter_otsu=False)
        elif self.default_pad_value == 'otsu':
            default_value = get_borders_mean(image, filter_otsu=True)
        else:
            default_value = self.default_pad_value

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(reference)
        resampler.SetDefaultPixelValue(float(default_value))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(floating)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor[0] = torch.from_numpy(np_array)
        return tensor
223
224
225
def get_borders_mean(image, filter_otsu=True):
    """Return the mean intensity of the voxels on the six faces of an image.

    Args:
        image: SimpleITK image whose array view is indexed with three
            indices, i.e. a 3D image.
        filter_otsu: If ``True``, an Otsu threshold is computed from the
            border values and only the border values strictly below it are
            averaged. If no border value lies below the threshold, the mean
            of all border values is returned instead.

    Returns:
        The mean of the selected border values.
    """
    array = sitk.GetArrayViewFromImage(image)
    borders_tuple = (
        array[0, :, :],
        array[-1, :, :],
        array[:, 0, :],
        array[:, -1, :],
        array[:, :, 0],
        array[:, :, -1],
    )
    borders_flat = np.hstack([border.ravel() for border in borders_tuple])
    if not filter_otsu:
        return borders_flat.mean()

    # The Otsu filter operates on images, so the flat border values are
    # wrapped in a 1 x 1 x N volume
    borders_reshaped = borders_flat.reshape(1, 1, -1)
    borders_image = sitk.GetImageFromArray(borders_reshaped)
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders_flat[borders_flat < threshold]

    # Test for emptiness with .size, not .any(): .any() is False when all
    # the selected values are zero (e.g. a zero background), which would
    # wrongly discard them and fall back to the mean of every border value
    if values.size > 0:
        return values.mean()
    return borders_flat.mean()
250