Passed
Push — master ( 2bbd92...263466 )
by Fernando
01:39
created

RandomAffine.parse_default_value()   A

Complexity

Conditions 3

Size

Total Lines 9
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 7
nop 1
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
from numbers import Number
2
from typing import Tuple, Optional, List, Union
3
import torch
0 ignored issues
show
introduced by
Unable to import 'torch'
Loading history...
4
import numpy as np
0 ignored issues
show
introduced by
Unable to import 'numpy'
Loading history...
5
import SimpleITK as sitk
0 ignored issues
show
introduced by
Unable to import 'SimpleITK'
Loading history...
6
from ....utils import is_image_dict, check_consistent_shape
7
from ....torchio import LABEL, DATA, AFFINE, TYPE, TypeRangeFloat
8
from .. import Interpolation, get_sitk_interpolator
9
from .. import RandomTransform
10
11
12
class RandomAffine(RandomTransform):
    r"""Random affine transformation.

    Args:
        scales: Tuple :math:`(a, b)` defining the scaling
            magnitude. The scaling values along each dimension are
            :math:`(s_1, s_2, s_3)`, where :math:`s_i \sim \mathcal{U}(a, b)`.
            For example, using ``scales=(0.5, 0.5)`` will zoom out the image,
            making the objects inside look twice as small while preserving
            the physical size and position of the image.
        degrees: Tuple :math:`(a, b)` defining the rotation range in degrees.
            The rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`\theta_i \sim \mathcal{U}(-d, d)`.
        isotropic: If ``True``, the scaling factor along all dimensions is the
            same, i.e. :math:`s_1 = s_2 = s_3`.
        default_pad_value: As the image is rotated, some values near the
            borders will be undefined.
            If ``'minimum'``, the fill value will be the image minimum.
            If ``'otsu'``, the fill value is the mean of the values at the
            border that lie under an
            `Otsu threshold <https://ieeexplore.ieee.org/document/4310076>`_.
            If a number, that number is used as the fill value.
        image_interpolation: See :ref:`Interpolation`.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.

    .. note:: Rotations and scaling are performed around the center of the
        image.

    Example:
        >>> from torchio.transforms import RandomAffine, Interpolation
        >>> sample = images_dataset[0]  # instance of torchio.ImagesDataset
        >>> transform = RandomAffine(
        ...     scales=(0.9, 1.2),
        ...     degrees=(10),
        ...     isotropic=False,
        ...     default_pad_value='otsu',
        ...     image_interpolation=Interpolation.BSPLINE,
        ... )
        >>> transformed = transform(sample)

    From the command line::

        $ torchio-transform t1.nii.gz RandomAffine \
            -k "degrees=30 default_pad_value=minimum" -s 42 affine_min.nii.gz

    """
    def __init__(
            self,
            scales: Tuple[float, float] = (0.9, 1.1),
            degrees: TypeRangeFloat = 10,
            isotropic: bool = False,
            default_pad_value: Union[str, float] = 'otsu',
            image_interpolation: Interpolation = Interpolation.LINEAR,
            seed: Optional[int] = None,
            ):
        super().__init__(seed=seed)
        self.scales = scales
        self.degrees = self.parse_degrees(degrees)
        self.isotropic = isotropic
        self.default_pad_value = self.parse_default_value(default_pad_value)
        self.interpolation = self.parse_interpolation(image_interpolation)

    @staticmethod
    def parse_default_value(value: Union[str, float]) -> Union[str, float]:
        """Validate the ``default_pad_value`` argument.

        Args:
            value: ``'minimum'``, ``'otsu'`` or a number.

        Returns:
            ``value`` unchanged if it is valid.

        Raises:
            ValueError: If ``value`` is neither one of the accepted keywords
                nor a number.
        """
        # Restrict the keyword membership test to strings so that array-like
        # values are not compared elementwise against the keywords
        is_keyword = isinstance(value, str) and value in ('minimum', 'otsu')
        if is_keyword or isinstance(value, Number):
            return value
        message = (
            'Value for default_pad_value must be "minimum", "otsu"'
            ' or a number'
        )
        raise ValueError(message)

    def apply_transform(self, sample: dict) -> dict:
        """Apply one random scaling + rotation to every image in ``sample``.

        The same sampled parameters are used for all images, and they are
        stored in ``sample`` under ``'random_scaling'`` and
        ``'random_rotation'``. Label maps are resampled with nearest-neighbor
        interpolation; other images use the configured interpolation.
        """
        check_consistent_shape(sample)
        scaling_params, rotation_params = self.get_params(
            self.scales, self.degrees, self.isotropic)
        sample['random_scaling'] = scaling_params
        sample['random_rotation'] = rotation_params
        for image_dict in sample.values():
            if not is_image_dict(image_dict):
                continue
            if image_dict[TYPE] == LABEL:
                interpolation = Interpolation.NEAREST
            else:
                interpolation = self.interpolation
            image_dict[DATA] = self.apply_affine_transform(
                image_dict[DATA],
                image_dict[AFFINE],
                scaling_params,
                rotation_params,
                interpolation,
            )
        return sample

    @staticmethod
    def get_params(
            scales: Tuple[float, float],
            degrees: Tuple[float, float],
            isotropic: bool,
            ) -> Tuple[List[float], List[float]]:
        """Sample three scaling factors and three rotation angles (degrees).

        Each value is drawn uniformly from its range. If ``isotropic`` is
        ``True``, the first scaling factor is reused for all three axes.
        """
        scaling_params = torch.FloatTensor(3).uniform_(*scales)
        if isotropic:
            scaling_params.fill_(scaling_params[0])
        rotation_params = torch.FloatTensor(3).uniform_(*degrees)
        return scaling_params.tolist(), rotation_params.tolist()

    @staticmethod
    def get_scaling_transform(
            scaling_params: List[float],
            center: Optional[Tuple[float, float, float]] = None,
            ) -> sitk.ScaleTransform:
        """Build a 3D scaling transform.

        ``scaling_params`` are inverted so that they are more intuitive.
        For example, 1.5 means the objects look 1.5 times larger.

        Args:
            scaling_params: Three scaling factors, one per axis.
            center: Optional fixed point of the scaling, in physical
                coordinates. If ``None``, the SimpleITK default is kept.
        """
        transform = sitk.ScaleTransform(3)
        inverted = (1 / np.asarray(scaling_params, dtype=float)).tolist()
        transform.SetScale(inverted)
        if center is not None:
            transform.SetCenter(center)
        return transform

    @staticmethod
    def get_rotation_transform(
            degrees: List[float],
            center: Optional[Tuple[float, float, float]] = None,
            ) -> sitk.Euler3DTransform:
        """Build a 3D Euler rotation from three angles given in degrees.

        Args:
            degrees: Rotation angles around each axis, in degrees.
            center: Optional center of rotation, in physical coordinates.
                If ``None``, the SimpleITK default is kept.
        """
        transform = sitk.Euler3DTransform()
        radians = np.radians(degrees)
        transform.SetRotation(*radians)
        if center is not None:
            transform.SetCenter(center)
        return transform

    @staticmethod
    def _get_image_center(image: sitk.Image) -> Tuple[float, ...]:
        """Return the physical coordinates of the center of ``image``."""
        # Continuous index halfway between the first and last voxel per axis
        center_index = [(size - 1) / 2 for size in image.GetSize()]
        return image.TransformContinuousIndexToPhysicalPoint(center_index)

    def _get_default_pad_value(
            self,
            tensor: torch.Tensor,
            image: sitk.Image,
            ) -> float:
        """Return the fill value for voxels mapped from outside the image."""
        if self.default_pad_value == 'minimum':
            return tensor.min().item()
        if self.default_pad_value == 'otsu':
            return float(get_borders_otsu(image))
        return float(self.default_pad_value)

    def apply_affine_transform(
            self,
            tensor: torch.Tensor,
            affine: np.ndarray,
            scaling_params: List[float],
            rotation_params: List[float],
            interpolation: Interpolation,
            ) -> torch.Tensor:
        """Resample a single-channel 4D tensor with a scaling + rotation.

        Both transforms are centered on the physical center of the image.
        Voxels that fall outside the input are filled according to
        ``default_pad_value``.

        Args:
            tensor: 4D tensor whose first dimension has length 1.
            affine: Affine matrix passed to ``nib_to_sitk`` with the volume.
            scaling_params: Three scaling factors.
            rotation_params: Three rotation angles in degrees.
            interpolation: Interpolation used by the resampler.

        Returns:
            ``tensor``, with its single channel overwritten in place by the
            resampled volume.
        """
        assert tensor.ndim == 4
        assert len(tensor) == 1

        image = self.nib_to_sitk(tensor[0], affine)
        center = self._get_image_center(image)

        transform = sitk.Transform(3, sitk.sitkComposite)
        transform.AddTransform(
            self.get_scaling_transform(scaling_params, center))
        transform.AddTransform(
            self.get_rotation_transform(rotation_params, center))

        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(get_sitk_interpolator(interpolation))
        # The input image is also the reference, so the output grid is unchanged
        resampler.SetReferenceImage(image)
        resampler.SetDefaultPixelValue(
            self._get_default_pad_value(tensor, image))
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetTransform(transform)
        resampled = resampler.Execute(image)

        np_array = sitk.GetArrayFromImage(resampled)
        np_array = np_array.transpose()  # ITK to NumPy
        tensor[0] = torch.from_numpy(np_array)
        return tensor
179
180
181
def get_borders_otsu(image):
    """Return the mean of the border voxels that lie below an Otsu threshold.

    The voxels on the six faces of the 3D ``image`` are collected. An Otsu
    threshold is computed over them, and the mean of the values strictly
    below that threshold is returned.

    Args:
        image: 3D SimpleITK image.

    Returns:
        The mean of the border values below the Otsu threshold, as a Python
        float. If no border value lies below the threshold (for example when
        the borders are constant), the mean of all border values is returned
        instead of ``nan``.
    """
    array = sitk.GetArrayViewFromImage(image)
    # Each face appears exactly once. Faces of a non-cubic volume have
    # different shapes, so they are flattened individually instead of being
    # stacked into one (ragged) array.
    faces = (
        array[0, :, :],
        array[-1, :, :],
        array[:, 0, :],
        array[:, -1, :],
        array[:, :, 0],
        array[:, :, -1],
    )
    borders = np.hstack([face.ravel() for face in faces])
    # Otsu filter needs an image; wrap the 1D values as a 1 x 1 x N volume
    borders_image = sitk.GetImageFromArray(borders.reshape(1, 1, -1))
    otsu = sitk.OtsuThresholdImageFilter()
    otsu.Execute(borders_image)
    threshold = otsu.GetThreshold()
    values = borders[borders < threshold]
    if values.size == 0:
        # Avoid returning nan (mean of an empty array) as the pad value
        values = borders
    return float(values.mean())
202