Completed
Pull Request — master (#353)
by Fernando
118:39 queued 117:31
created

Motion.__init__()   A

Complexity

Conditions 1

Size

Total Lines 18
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 17
nop 6
dl 0
loc 18
rs 9.55
c 0
b 0
f 0
1
from collections import defaultdict
2
from typing import Tuple, Optional, Sequence, List, Union, Dict
3
4
import torch
5
import numpy as np
6
import SimpleITK as sitk
7
8
from ....utils import nib_to_sitk
9
from ....torchio import DATA, AFFINE, TypeTripletFloat
10
from ....data.subject import Subject
11
from ... import IntensityTransform, FourierTransform
12
from .. import RandomTransform
13
14
15
class RandomMotion(RandomTransform, IntensityTransform, FourierTransform):
    r"""Add random MRI motion artifact.

    Magnetic resonance images suffer from motion artifacts when the subject
    moves during image acquisition. This transform follows
    `Shaw et al., 2019 <http://proceedings.mlr.press/v102/shaw19a.html>`_ to
    simulate motion artifacts for data augmentation.

    Args:
        degrees: Tuple :math:`(a, b)` defining the rotation range in degrees of
            the simulated movements. The rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`\theta_i \sim \mathcal{U}(-d, d)`.
            Larger values generate more distorted images.
        translation: Tuple :math:`(a, b)` defining the translation in mm of
            the simulated movements. The translations along each axis are
            :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`t` is provided,
            :math:`t_i \sim \mathcal{U}(-t, t)`.
            Larger values generate more distorted images.
        num_transforms: Number of simulated movements. Must be a strictly
            positive integer.
            Larger values generate more distorted images.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        keys: See :class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If ``num_transforms`` is not a strictly positive integer.

    .. warning:: Large numbers of movements lead to longer execution times for
        3D images.
    """
    def __init__(
            self,
            degrees: Union[float, Tuple[float, float]] = 10,
            translation: Union[float, Tuple[float, float]] = 10,  # in mm
            num_transforms: int = 2,
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.degrees_range = self.parse_degrees(degrees)
        self.translation_range = self.parse_translation(translation)
        # Check the type before comparing, so that non-numeric input raises
        # the intended ValueError rather than a TypeError from ``<=``.
        if not isinstance(num_transforms, int) or num_transforms <= 0:
            message = (
                'Number of transforms must be a strictly positive natural'
                f' number, not {num_transforms}'
            )
            raise ValueError(message)
        self.num_transforms = num_transforms
        self.image_interpolation = self.parse_interpolation(image_interpolation)

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample per-image motion parameters and apply :class:`Motion`."""
        # Maps argument name -> {image name -> value}
        arguments = defaultdict(dict)
        for name, image in self.get_images_dict(subject).items():
            times_params, degrees_params, translation_params = self.get_params(
                self.degrees_range,
                self.translation_range,
                self.num_transforms,
                is_2d=image.is_2d(),
            )
            arguments['times'][name] = times_params
            arguments['degrees'][name] = degrees_params
            arguments['translation'][name] = translation_params
            arguments['image_interpolation'][name] = self.image_interpolation
        if not arguments:
            # No image was selected. Motion requires all four arguments, so
            # building it from an empty dict would raise a TypeError.
            return subject
        # Restrict Motion to the images we sampled parameters for, so it never
        # looks up an image name missing from the per-image dictionaries.
        transform = Motion(**arguments, keys=list(arguments['times']))
        return transform(subject)

    def get_params(
            self,
            degrees_range: Tuple[float, float],
            translation_range: Tuple[float, float],
            num_transforms: int,
            perturbation: float = 0.3,
            is_2d: bool = False,
            ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Sample the times, rotations and translations of the movements.

        Args:
            degrees_range: Range from which rotation angles are drawn.
            translation_range: Range from which translations are drawn.
            num_transforms: Number of movements.
            perturbation: Maximum jitter of each time, as a fraction of the
                spacing between evenly spaced times. If 0, the time intervals
                between movements are constant.
            is_2d: If ``True``, motion is kept in the plane of the first two
                axes.

        Returns:
            Tuple ``(times, degrees, translation)``, where ``times`` has shape
            ``(num_transforms,)`` and the other two ``(num_transforms, 3)``.
        """
        degrees_params = self.get_params_array(degrees_range, num_transforms)
        translation_params = self.get_params_array(
            translation_range, num_transforms)
        if is_2d:
            # Keep the motion in-plane: rotate around the third (Z) axis only
            # and translate along the first two (X, Y) axes only.
            # NOTE(review): assumes 2D images have a singleton third spatial
            # dimension; confirm against Image.is_2d().
            degrees_params[:, :-1] = 0
            translation_params[:, 2] = 0
        step = 1 / (num_transforms + 1)
        # k / (num_transforms + 1) for k = 1..num_transforms. linspace always
        # yields exactly num_transforms values. arange with a float step can
        # yield an extra element through rounding, which would break the
        # in-place addition of the noise below.
        times = torch.linspace(0, 1, num_transforms + 2)[1:-1]
        noise = torch.empty(num_transforms, dtype=torch.float32)
        noise.uniform_(-step * perturbation, step * perturbation)
        times += noise
        return times.numpy(), degrees_params, translation_params

    @staticmethod
    def get_params_array(nums_range: Tuple[float, float], num_transforms: int):
        """Draw a ``(num_transforms, 3)`` float32 array from ``U(nums_range)``."""
        tensor = torch.empty(num_transforms, 3, dtype=torch.float32)
        tensor.uniform_(*nums_range)
        return tensor.numpy()
114
115
116
class Motion(IntensityTransform, FourierTransform):
    r"""Add MRI motion artifact.

    Magnetic resonance images suffer from motion artifacts when the subject
    moves during image acquisition. This transform follows
    `Shaw et al., 2019 <http://proceedings.mlr.press/v102/shaw19a.html>`_ to
    simulate motion artifacts for data augmentation.

    Args:
        degrees: Sequence of rotations :math:`(\theta_1, \theta_2, \theta_3)`
            in degrees, one triplet per simulated movement.
        translation: Sequence of translations :math:`(t_1, t_2, t_3)` in mm,
            one triplet per simulated movement.
        times: Sequence of times from 0 to 1 at which the motions happen,
            one per simulated movement.
        image_interpolation: See :ref:`Interpolation`.
        keys: See :class:`~torchio.transforms.Transform`.

    Alternatively, all four of ``degrees``, ``translation``, ``times`` and
    ``image_interpolation`` can be dictionaries mapping image names to
    per-image values, as built by :class:`RandomMotion`.
    """
    def __init__(
            self,
            degrees: Union[
                Sequence[TypeTripletFloat],
                Dict[str, Sequence[TypeTripletFloat]],
            ],
            translation: Union[
                Sequence[TypeTripletFloat],
                Dict[str, Sequence[TypeTripletFloat]],
            ],
            times: Union[Sequence[float], Dict[str, Sequence[float]]],
            image_interpolation: Union[str, Dict[str, str]],
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(keys=keys)
        self.degrees = degrees
        self.translation = translation
        self.times = times
        self.image_interpolation = image_interpolation
        self.args_names = (
            'degrees',
            'translation',
            'times',
            'image_interpolation',
        )

    def apply_transform(self, subject: Subject) -> Subject:
        """Add the motion artifact to each selected image, channel by channel."""
        degrees = self.degrees
        translation = self.translation
        times = self.times
        image_interpolation = self.image_interpolation
        for image_name, image in self.get_images_dict(subject).items():
            if self.arguments_are_dict():
                degrees = self.degrees[image_name]
                translation = self.translation[image_name]
                times = self.times[image_name]
                image_interpolation = self.image_interpolation[image_name]
            result_arrays = []
            # NOTE(review): a static analyzer reported "The variable DATA does
            # not seem to be defined" here. DATA is imported from ....torchio
            # at the top of this module.
            # Iterates over the first axis of the data (presumably channels).
            for data in image[DATA]:
                sitk_image = nib_to_sitk(
                    data[np.newaxis],
                    image[AFFINE],
                    force_3d=True,
                )
                transforms = self.get_rigid_transforms(
                    degrees,
                    translation,
                    sitk_image,
                )
                data = self.add_artifact(
                    sitk_image,
                    transforms,
                    times,
                    image_interpolation,
                )
                result_arrays.append(data)
            result = np.stack(result_arrays)
            image[DATA] = torch.from_numpy(result)
        return subject

    def get_rigid_transforms(
            self,
            degrees_params: np.ndarray,
            translation_params: np.ndarray,
            image: sitk.Image,
            ) -> List[sitk.Euler3DTransform]:
        """Build the identity followed by one rigid transform per movement.

        Rotations are applied about the geometric center of ``image``.
        ``degrees_params`` and ``translation_params`` may be arrays or plain
        sequences of triplets.
        """
        # Integer indices are voxel centers, so the center of an axis with N
        # voxels is at continuous index (N - 1) / 2.
        center_ijk = (np.array(image.GetSize()) - 1) / 2
        center_lps = image.TransformContinuousIndexToPhysicalPoint(
            center_ijk.tolist())
        matrices = [np.eye(4)]  # first transform is the identity
        for degrees, translation in zip(degrees_params, translation_params):
            radians = np.radians(degrees).tolist()
            motion = sitk.Euler3DTransform()
            motion.SetCenter(center_lps)
            motion.SetRotation(*radians)
            motion.SetTranslation(
                np.asarray(translation, dtype=float).tolist())
            matrices.append(self.transform_to_matrix(motion))
        return [self.matrix_to_transform(m) for m in matrices]

    @staticmethod
    def transform_to_matrix(transform: sitk.Euler3DTransform) -> np.ndarray:
        """Return the 4x4 homogeneous matrix equivalent to ``transform``.

        ITK applies ``T(x) = R (x - c) + c + t``, so the offset column of the
        homogeneous matrix is ``t + c - R c``. Using only ``t`` would silently
        drop the rotation center.
        """
        rotation = np.array(transform.GetMatrix()).reshape(3, 3)
        center = np.array(transform.GetCenter())
        translation = np.array(transform.GetTranslation())
        matrix = np.eye(4)
        matrix[:3, :3] = rotation
        matrix[:3, 3] = translation + center - rotation @ center
        return matrix

    @staticmethod
    def matrix_to_transform(matrix: np.ndarray) -> sitk.Euler3DTransform:
        """Build an Euler transform from a 4x4 homogeneous rigid matrix.

        No center is set, so ITK's default center (the origin) applies and the
        translation equals the matrix offset column.
        """
        transform = sitk.Euler3DTransform()
        transform.SetMatrix(matrix[:3, :3].flatten().tolist())
        transform.SetTranslation(matrix[:3, 3].tolist())
        return transform

    def resample_images(
            self,
            image: sitk.Image,
            transforms: Sequence[sitk.Euler3DTransform],
            interpolation: str,
            ) -> List[sitk.Image]:
        """Return ``image`` followed by its resampling with each transform.

        The first transform is assumed to be the identity and is skipped; the
        input image is used in its place. Resampled images are float32 and
        are padded with the minimum intensity of ``image``.
        """
        default_value = np.float64(sitk.GetArrayViewFromImage(image).min())
        # These settings do not depend on the transform, so the filter is
        # configured once and reused.
        resampler = sitk.ResampleImageFilter()
        resampler.SetInterpolator(self.get_sitk_interpolator(interpolation))
        resampler.SetReferenceImage(image)
        resampler.SetOutputPixelType(sitk.sitkFloat32)
        resampler.SetDefaultPixelValue(default_value)
        images = [image]  # first is identity
        for transform in transforms[1:]:
            resampler.SetTransform(transform)
            images.append(resampler.Execute(image))
        return images

    @staticmethod
    def sort_spectra(spectra: List[np.ndarray], times: np.ndarray):
        """Use original spectrum to fill the center of k-space.

        Swaps, in place, ``spectra[0]`` (the unmoved image) with the spectrum
        that :meth:`add_artifact` assigns to the segment ending at the first
        time greater than 0.5, or with the last spectrum if there is none.
        """
        times = np.asarray(times)
        num_spectra = len(spectra)
        if np.any(times > 0.5):
            index = np.where(times > 0.5)[0].min()
        else:
            index = num_spectra - 1
        spectra[0], spectra[index] = spectra[index], spectra[0]

    def add_artifact(
            self,
            image: sitk.Image,
            transforms: Sequence[sitk.Euler3DTransform],
            times: np.ndarray,
            interpolation: str,
            ):
        """Combine k-space segments of the moved images into one image.

        The spectrum of the i-th image fills the slab of the last k-space axis
        between ``times[i - 1]`` and ``times[i]`` (0 and 1 at the ends).

        Returns:
            The real part of the inverse Fourier transform, as float32.
        """
        # A plain list would make ``num_columns * times`` repeat the list.
        times = np.asarray(times, dtype=float)
        images = self.resample_images(image, transforms, interpolation)
        # SimpleITK arrays use reversed axis order; transpose back (ITK to
        # NumPy).
        arrays = [sitk.GetArrayViewFromImage(im).transpose() for im in images]
        spectra = [self.fourier_transform(array) for array in arrays]
        self.sort_spectra(spectra, times)
        # zeros rather than empty: never leave uninitialized values if some
        # slab is not covered (e.g. times outside [0, 1]).
        result_spectrum = np.zeros_like(spectra[0])
        num_columns = result_spectrum.shape[2]
        stops = (num_columns * times).astype(int).tolist()
        stops.append(num_columns)
        start = 0
        for spectrum, stop in zip(spectra, stops):
            result_spectrum[..., start:stop] = spectrum[..., start:stop]
            start = stop
        result_image = np.real(self.inv_fourier_transform(result_spectrum))
        return result_image.astype(np.float32)
275