|
1
|
|
|
from collections import defaultdict |
|
2
|
|
|
from typing import Tuple, Optional, Sequence, List, Union, Dict |
|
3
|
|
|
|
|
4
|
|
|
import torch |
|
5
|
|
|
import numpy as np |
|
6
|
|
|
import SimpleITK as sitk |
|
7
|
|
|
|
|
8
|
|
|
from ....utils import nib_to_sitk |
|
9
|
|
|
from ....torchio import DATA, AFFINE, TypeTripletFloat |
|
10
|
|
|
from ....data.subject import Subject |
|
11
|
|
|
from ... import IntensityTransform, FourierTransform |
|
12
|
|
|
from .. import RandomTransform |
|
13
|
|
|
|
|
14
|
|
|
|
|
15
|
|
|
class RandomMotion(RandomTransform, IntensityTransform, FourierTransform):
    r"""Add random MRI motion artifact.

    Magnetic resonance images suffer from motion artifacts when the subject
    moves during image acquisition. This transform follows
    `Shaw et al., 2019 <http://proceedings.mlr.press/v102/shaw19a.html>`_ to
    simulate motion artifacts for data augmentation.

    Args:
        degrees: Tuple :math:`(a, b)` defining the rotation range in degrees of
            the simulated movements. The rotation angles around each axis are
            :math:`(\theta_1, \theta_2, \theta_3)`,
            where :math:`\theta_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`\theta_i \sim \mathcal{U}(-d, d)`.
            Larger values generate more distorted images.
        translation: Tuple :math:`(a, b)` defining the translation in mm of
            the simulated movements. The translations along each axis are
            :math:`(t_1, t_2, t_3)`,
            where :math:`t_i \sim \mathcal{U}(a, b)`.
            If only one value :math:`t` is provided,
            :math:`t_i \sim \mathcal{U}(-t, t)`.
            Larger values generate more distorted images.
        num_transforms: Number of simulated movements.
            Larger values generate more distorted images.
        image_interpolation: See :ref:`Interpolation`.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. warning:: Large numbers of movements lead to longer execution times for
        3D images.
    """

    def __init__(
            self,
            degrees: float = 10,
            translation: float = 10,  # in mm
            num_transforms: int = 2,
            image_interpolation: str = 'linear',
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        self.degrees_range = self.parse_degrees(degrees)
        self.translation_range = self.parse_translation(translation)
        # Check the type before comparing: ``0 < 'a'`` would raise TypeError
        # instead of the intended ValueError.
        if not isinstance(num_transforms, int) or num_transforms <= 0:
            message = (
                'Number of transforms must be a strictly positive natural '
                f'number, not {num_transforms}'
            )
            raise ValueError(message)
        self.num_transforms = num_transforms
        self.image_interpolation = self.parse_interpolation(image_interpolation)

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample motion parameters per image and apply :class:`Motion`.

        Each image returned by ``get_images_dict`` gets its own independently
        sampled times, rotations and translations.
        """
        arguments = defaultdict(dict)
        for name, image in self.get_images_dict(subject).items():
            params = self.get_params(
                self.degrees_range,
                self.translation_range,
                self.num_transforms,
                is_2d=image.is_2d(),
            )
            times_params, degrees_params, translation_params = params
            arguments['times'][name] = times_params
            arguments['degrees'][name] = degrees_params
            arguments['translation'][name] = translation_params
            arguments['image_interpolation'][name] = self.image_interpolation
        if not arguments:
            # Nothing to corrupt; Motion's parameters are all required, so
            # ``Motion(**{})`` would raise a TypeError.
            return subject
        transform = Motion(**arguments)
        transformed = transform(subject)
        return transformed

    def get_params(
            self,
            degrees_range: Tuple[float, float],
            translation_range: Tuple[float, float],
            num_transforms: int,
            perturbation: float = 0.3,
            is_2d: bool = False,
            ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Sample movement times, rotations and translations.

        Args:
            degrees_range: Bounds of the uniform distribution for the angles.
            translation_range: Bounds of the uniform distribution for the
                translations.
            num_transforms: Number of simulated movements.
            perturbation: Maximum jitter of each movement time, as a fraction
                of the spacing between movements. If 0, time intervals
                between movements are constant.
            is_2d: If True, only in-plane motion is sampled.

        Returns:
            ``(times, degrees, translation)``: ``times`` has shape
            ``(num_transforms,)``; the other two have shape
            ``(num_transforms, 3)``.
        """
        degrees_params = self.get_params_array(
            degrees_range, num_transforms)
        translation_params = self.get_params_array(
            translation_range, num_transforms)
        if is_2d:
            # NOTE(review): assumes 2D images have a singleton third spatial
            # axis, i.e. (W, H, 1), so in-plane motion is XY — confirm
            # against Image.is_2d.
            degrees_params[:, :-1] = 0  # rotate around Z axis only
            translation_params[:, 2] = 0  # translate in XY plane only
        step = 1 / (num_transforms + 1)
        # Evenly spaced times step, 2 * step, ..., num_transforms * step.
        # They are built from an integer range so the length is always exactly
        # num_transforms. torch.arange(0, 1, step) can return an extra element
        # when floating-point rounding makes 1 / step exceed
        # num_transforms + 1, which breaks ``times += noise``.
        times = torch.arange(1, num_transforms + 1, dtype=torch.float32) * step
        noise = torch.empty(num_transforms, dtype=torch.float32)
        noise.uniform_(-step * perturbation, step * perturbation)
        times += noise
        times_params = times.numpy()
        return times_params, degrees_params, translation_params

    @staticmethod
    def get_params_array(nums_range: Tuple[float, float], num_transforms: int):
        """Return a ``(num_transforms, 3)`` array drawn from U(nums_range)."""
        tensor = torch.FloatTensor(num_transforms, 3).uniform_(*nums_range)
        return tensor.numpy()
|
114
|
|
|
|
|
115
|
|
|
|
|
116
|
|
|
class Motion(IntensityTransform, FourierTransform):
    r"""Add MRI motion artifact.

    Magnetic resonance images suffer from motion artifacts when the subject
    moves during image acquisition. This transform follows
    `Shaw et al., 2019 <http://proceedings.mlr.press/v102/shaw19a.html>`_ to
    simulate motion artifacts for data augmentation.

    Args:
        degrees: Sequence of rotations :math:`(\theta_1, \theta_2, \theta_3)`
            in degrees, one triplet per simulated movement.
        translation: Sequence of translations :math:`(t_1, t_2, t_3)` in mm,
            one triplet per simulated movement.
        times: Sequence of times from 0 to 1 at which the motions happen,
            one per simulated movement.
        image_interpolation: See :ref:`Interpolation`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    The arguments may also be dictionaries mapping image names to per-image
    values. In that case ``arguments_are_dict()`` selects the per-image lookup.
    """

    def __init__(
            self,
            degrees: Union[
                Sequence[TypeTripletFloat],
                Dict[str, Sequence[TypeTripletFloat]],
            ],
            translation: Union[
                Sequence[TypeTripletFloat],
                Dict[str, Sequence[TypeTripletFloat]],
            ],
            times: Union[Sequence[float], Dict[str, Sequence[float]]],
            image_interpolation: Union[str, Dict[str, str]],
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(keys=keys)
        self.degrees = degrees
        self.translation = translation
        self.times = times
        self.image_interpolation = image_interpolation
        self.args_names = (
            'degrees',
            'translation',
            'times',
            'image_interpolation',
        )

    def apply_transform(self, subject: Subject) -> Subject:
        """Replace the data of each image with its motion-corrupted version."""
        degrees = self.degrees
        translation = self.translation
        times = self.times
        image_interpolation = self.image_interpolation
        for image_name, image in self.get_images_dict(subject).items():
            if self.arguments_are_dict():
                # Per-image parameters, keyed by image name
                degrees = self.degrees[image_name]
                translation = self.translation[image_name]
                times = self.times[image_name]
                image_interpolation = self.image_interpolation[image_name]
            result_arrays = []
            # Iterates over the first axis of the data (presumably channels);
            # each slice is corrupted separately with the same motion.
            for data in image[DATA]:
                sitk_image = nib_to_sitk(
                    data[np.newaxis],
                    image[AFFINE],
                    force_3d=True,
                )
                transforms = self.get_rigid_transforms(
                    degrees,
                    translation,
                    sitk_image,
                )
                data = self.add_artifact(
                    sitk_image,
                    transforms,
                    times,
                    image_interpolation,
                )
                result_arrays.append(data)
            result = np.stack(result_arrays)
            image[DATA] = torch.from_numpy(result)
        return subject

    def get_rigid_transforms(
            self,
            degrees_params: np.ndarray,
            translation_params: np.ndarray,
            image: sitk.Image,
            ) -> List[sitk.Euler3DTransform]:
        """Return the identity followed by one rigid transform per movement.

        Every rotation is centered on the physical point corresponding to the
        middle of the image grid.
        """
        # NOTE(review): in ITK continuous-index coordinates the geometric
        # center of the grid is (size - 1) / 2; size / 2 is half a voxel off.
        # Kept as-is to preserve existing behavior.
        center_ijk = np.array(image.GetSize()) / 2
        center_lps = image.TransformContinuousIndexToPhysicalPoint(
            center_ijk.tolist())
        identity = np.eye(4)
        matrices = [identity]
        for degrees, translation in zip(degrees_params, translation_params):
            radians = np.radians(degrees).tolist()
            motion = sitk.Euler3DTransform()
            motion.SetCenter(center_lps)
            motion.SetRotation(*radians)
            # np.asarray lets plain sequences be passed as well as arrays
            motion.SetTranslation(np.asarray(translation, dtype=float).tolist())
            matrices.append(self.transform_to_matrix(motion))
        transforms = [self.matrix_to_transform(m) for m in matrices]
        return transforms

    @staticmethod
    def transform_to_matrix(transform: sitk.Euler3DTransform) -> np.ndarray:
        """Return the 4x4 homogeneous matrix equivalent to ``transform``.

        SimpleITK applies ``y = R (x - c) + c + t`` for a transform with center
        ``c``. The homogeneous offset is therefore ``t + c - R c``, not only the
        translation ``t``. Dropping the center would make the rotation pivot
        about the physical origin.
        """
        rotation = np.array(transform.GetMatrix()).reshape(3, 3)
        center = np.array(transform.GetCenter())
        translation = np.array(transform.GetTranslation())
        matrix = np.eye(4)
        matrix[:3, :3] = rotation
        matrix[:3, 3] = translation + center - rotation @ center
        return matrix

    @staticmethod
    def matrix_to_transform(matrix: np.ndarray) -> sitk.Euler3DTransform:
        """Build a zero-centered Euler transform from a 4x4 rigid matrix."""
        transform = sitk.Euler3DTransform()
        rotation = matrix[:3, :3].flatten().tolist()
        transform.SetMatrix(rotation)
        transform.SetTranslation(matrix[:3, 3].tolist())
        return transform

    def resample_images(
            self,
            image: sitk.Image,
            transforms: Sequence[sitk.Euler3DTransform],
            interpolation: str,
            ) -> List[sitk.Image]:
        """Resample ``image`` with each transform on its own grid.

        The first transform is assumed to be the identity, so the input image
        is returned unchanged in that position. Points mapped outside the
        image are filled with the image minimum.
        """
        floating = reference = image
        default_value = np.float64(sitk.GetArrayViewFromImage(image).min())
        transforms = transforms[1:]  # first is identity
        images = [image]  # first is identity
        for transform in transforms:
            resampler = sitk.ResampleImageFilter()
            resampler.SetInterpolator(self.get_sitk_interpolator(interpolation))
            resampler.SetReferenceImage(reference)
            resampler.SetOutputPixelType(sitk.sitkFloat32)
            resampler.SetDefaultPixelValue(default_value)
            resampler.SetTransform(transform)
            resampled = resampler.Execute(floating)
            images.append(resampled)
        return images

    @staticmethod
    def sort_spectra(spectra: List[np.ndarray], times: np.ndarray):
        """Use original spectrum to fill the center of k-space.

        ``spectra[0]`` is the spectrum of the unmoved image. It is swapped, in
        place, with the spectrum whose segment in :meth:`add_artifact`
        contains time 0.5: the first one whose start time exceeds 0.5, or the
        last one if no time does.
        """
        times = np.asarray(times)
        num_spectra = len(spectra)
        if np.any(times > 0.5):
            index = np.where(times > 0.5)[0].min()
        else:
            index = num_spectra - 1
        spectra[0], spectra[index] = spectra[index], spectra[0]

    def add_artifact(
            self,
            image: sitk.Image,
            transforms: Sequence[sitk.Euler3DTransform],
            times: np.ndarray,
            interpolation: str,
            ):
        """Combine k-space segments of rigidly moved copies of ``image``.

        Args:
            image: Image to corrupt.
            transforms: Identity followed by one transform per movement.
            times: Movement times in [0, 1), one per non-identity transform.
            interpolation: Interpolation used for resampling.

        Returns:
            Real part of the inverse Fourier transform of the combined
            spectrum, as ``float32``.
        """
        # Accept plain sequences: ``int * list`` would repeat the list.
        times = np.asarray(times)
        images = self.resample_images(image, transforms, interpolation)
        arrays = [sitk.GetArrayViewFromImage(im) for im in images]
        arrays = [array.transpose() for array in arrays]  # ITK to NumPy
        spectra = [self.fourier_transform(array) for array in arrays]
        self.sort_spectra(spectra, times)
        result_spectrum = np.empty_like(spectra[0])
        # k-space is split into contiguous segments along axis 2; segment i is
        # copied from spectra[i]. The first segment starts at 0 and the last
        # ends at last_index, so every element of result_spectrum is written.
        # NOTE(review): if axis 2 has length 1 (presumably 2D images), every
        # boundary truncates to 0 and the whole result comes from the last
        # spectrum, so no ghosting is produced — confirm this is intended.
        last_index = result_spectrum.shape[2]
        indices = (last_index * times).astype(int).tolist()
        indices.append(last_index)
        ini = 0
        for spectrum, fin in zip(spectra, indices):
            result_spectrum[..., ini:fin] = spectrum[..., ini:fin]
            ini = fin
        result_image = np.real(self.inv_fourier_transform(result_spectrum))
        return result_image.astype(np.float32)
|
275
|
|
|
|