|
1
|
|
|
from collections import defaultdict |
|
2
|
|
|
from typing import Tuple, Optional, Union, Sequence, Dict |
|
3
|
|
|
|
|
4
|
|
|
import torch |
|
5
|
|
|
import numpy as np |
|
6
|
|
|
|
|
7
|
|
|
from ....torchio import DATA |
|
8
|
|
|
from ....data.subject import Subject |
|
9
|
|
|
from ... import IntensityTransform, FourierTransform |
|
10
|
|
|
from .. import RandomTransform |
|
11
|
|
|
|
|
12
|
|
|
|
|
13
|
|
|
class RandomGhosting(RandomTransform, IntensityTransform):
    r"""Add random MRI ghosting artifact.

    Discrete "ghost" artifacts may occur along the phase-encode direction
    whenever the position or signal intensity of imaged structures within the
    field-of-view vary or move in a regular (periodic) fashion. Pulsatile flow
    of blood or CSF, cardiac motion, and respiratory motion are the most
    important patient-related causes of ghost artifacts in clinical MR imaging
    (from `mriquestions.com <http://mriquestions.com/why-discrete-ghosts.html>`_).

    Args:
        num_ghosts: Number of 'ghosts' :math:`n` in the image.
            If :py:attr:`num_ghosts` is a tuple :math:`(a, b)`, then
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
            If only one value :math:`d` is provided,
            :math:`n \sim \mathcal{U}(0, d) \cap \mathbb{N}`.
        axes: Axis along which the ghosts will be created. If
            :py:attr:`axes` is a tuple, the axis will be randomly chosen
            from the passed values. Anatomical labels may also be used (see
            :py:class:`~torchio.transforms.augmentation.RandomFlip`).
            At least one axis must be given. Axis ``2`` is ignored for 2D
            images.
        intensity: Positive number representing the artifact strength
            :math:`s` with respect to the maximum of the :math:`k`-space.
            If ``0``, the ghosts will not be visible. If a tuple
            :math:`(a, b)` is provided then :math:`s \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`s \sim \mathcal{U}(0, d)`.
        restore: Number between ``0`` and ``1`` indicating how much of the
            :math:`k`-space center should be restored after removing the planes
            that generate the artifact.
        p: Probability that this transform will be applied.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Raises:
        ValueError: If no axis is given, or if an integer axis is not in
            ``(0, 1, 2)``.

    .. note:: The execution time of this transform does not depend on the
        number of ghosts.
    """
    def __init__(
            self,
            num_ghosts: Union[int, Tuple[int, int]] = (4, 10),
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            intensity: Union[float, Tuple[float, float]] = (0.5, 1),
            restore: float = 0.02,
            p: float = 1,
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(p=p, keys=keys)
        if not isinstance(axes, tuple):
            try:
                axes = tuple(axes)
            except TypeError:
                # A single, non-iterable axis such as ``0``
                axes = (axes,)
        if not axes:
            # Otherwise sampling later fails inside torch.randint(0, 0, ...)
            raise ValueError('At least one axis must be specified')
        for axis in axes:
            if not isinstance(axis, str) and axis not in (0, 1, 2):
                raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"')
        self.axes = axes
        self.num_ghosts_range = self.parse_range(
            num_ghosts, 'num_ghosts', min_constraint=0, type_constraint=int)
        # The name is used in validation messages; it must match the argument
        self.intensity_range = self.parse_range(
            intensity, 'intensity', min_constraint=0)
        self.restore = _parse_restore(restore)

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample ghosting parameters per image and apply :class:`Ghosting`.

        Args:
            subject: Subject whose images will be transformed.

        Returns:
            The output of a :class:`Ghosting` instance built with one set of
            sampled parameters per image name.

        Raises:
            ValueError: If a 2D image has no usable axis (only axis ``2`` was
                requested).
        """
        arguments = defaultdict(dict)
        # Anatomical labels are only meaningful if all images share an
        # orientation.
        # NOTE(review): string labels are passed on to Ghosting unchanged,
        # but Ghosting.add_artifact only handles integer axes 0, 1, 2 —
        # confirm where labels are meant to be converted to indices.
        if any(isinstance(n, str) for n in self.axes):
            subject.check_consistent_orientation()
        for name, image in self.get_images_dict(subject).items():
            # Axis 2 is excluded for 2D images (presumably the singleton
            # dimension)
            is_2d = image.is_2d()
            axes = [a for a in self.axes if a != 2] if is_2d else self.axes
            if not axes:
                message = (
                    f'Image "{name}" is 2D, so axis 2 cannot be used;'
                    f' no valid axis left among {self.axes}'
                )
                raise ValueError(message)
            params = self.get_params(
                self.num_ghosts_range,
                axes,
                self.intensity_range,
            )
            num_ghosts_param, axis_param, intensity_param = params
            arguments['num_ghosts'][name] = num_ghosts_param
            arguments['axis'][name] = axis_param
            arguments['intensity'][name] = intensity_param
            arguments['restore'][name] = self.restore
        transform = Ghosting(**arguments)
        transformed = transform(subject)
        return transformed

    def get_params(
            self,
            num_ghosts_range: Tuple[int, int],
            axes: Sequence[int],
            intensity_range: Tuple[float, float],
            ) -> Tuple:
        """Sample the number of ghosts, the axis and the intensity.

        Args:
            num_ghosts_range: Inclusive ``(min, max)`` number of ghosts.
            axes: Non-empty sequence of candidate axes; one is picked
                uniformly at random.
            intensity_range: Range passed to ``self.sample_uniform``.

        Returns:
            Tuple ``(num_ghosts, axis, intensity)``.
        """
        ng_min, ng_max = num_ghosts_range
        # torch.randint's upper bound is exclusive, hence the + 1
        num_ghosts = torch.randint(int(ng_min), int(ng_max) + 1, (1,)).item()
        axis_index = torch.randint(0, len(axes), (1,)).item()
        axis = axes[axis_index]
        intensity = self.sample_uniform(*intensity_range).item()
        return num_ghosts, axis, intensity
|
105
|
|
|
|
|
106
|
|
|
|
|
107
|
|
|
class Ghosting(IntensityTransform, FourierTransform):
    r"""Add MRI ghosting artifact.

    Discrete "ghost" artifacts may occur along the phase-encode direction
    whenever the position or signal intensity of imaged structures within the
    field-of-view vary or move in a regular (periodic) fashion. Pulsatile flow
    of blood or CSF, cardiac motion, and respiratory motion are the most
    important patient-related causes of ghost artifacts in clinical MR imaging
    (from `mriquestions.com <http://mriquestions.com/why-discrete-ghosts.html>`_).

    Args:
        num_ghosts: Number of 'ghosts' :math:`n` in the image.
        axis: Axis along which the ghosts will be created.
        intensity: Positive number representing the artifact strength
            :math:`s` with respect to the maximum of the :math:`k`-space.
            If ``0``, the ghosts will not be visible.
        restore: Number between ``0`` and ``1`` indicating how much of the
            :math:`k`-space center should be restored after removing the planes
            that generate the artifact. Along :py:attr:`axis`,
            ``round(restore * size)`` central planes are restored, and at
            least the central plane is always restored.
        keys: See :py:class:`~torchio.transforms.Transform`.

    Each of ``num_ghosts``, ``axis``, ``intensity`` and ``restore`` may also be
    a dictionary; in that case (see ``arguments_are_dict``) the values are
    looked up by image name.

    .. note:: The execution time of this transform does not depend on the
        number of ghosts.
    """
    def __init__(
            self,
            num_ghosts: Union[int, Dict[str, int]],
            axis: Union[int, Dict[str, int]],
            intensity: Union[float, Dict[str, float]],
            restore: Union[float, Dict[str, float]],
            keys: Optional[Sequence[str]] = None,
            ):
        super().__init__(keys=keys)
        self.axis = axis
        self.num_ghosts = num_ghosts
        self.intensity = intensity
        self.restore = restore
        self.args_names = 'num_ghosts', 'axis', 'intensity', 'restore'

    def apply_transform(self, subject: Subject) -> Subject:
        """Add ghosts to each selected image of ``subject``.

        The data of each image is replaced and ``subject`` is returned.
        """
        axis = self.axis
        num_ghosts = self.num_ghosts
        intensity = self.intensity
        restore = self.restore
        for name, image in self.get_images_dict(subject).items():
            if self.arguments_are_dict():
                # Per-image parameters, keyed by image name
                axis = self.axis[name]
                num_ghosts = self.num_ghosts[name]
                intensity = self.intensity[name]
                restore = self.restore[name]
            # Iterating image.data yields slices along its first dimension
            # (presumably channels); each is processed independently
            transformed_tensors = [
                self.add_artifact(
                    tensor,
                    num_ghosts,
                    axis,
                    intensity,
                    restore,
                )
                for tensor in image.data
            ]
            image[DATA] = torch.stack(transformed_tensors)
        return subject

    def add_artifact(
            self,
            tensor: torch.Tensor,
            num_ghosts: int,
            axis: int,
            intensity: float,
            restore_center: float,
            ):
        """Return ``tensor`` with ghosting added along ``axis``.

        Every ``num_ghosts``-th plane of the spectrum along ``axis`` is scaled
        by ``1 - intensity``. Then a band of central planes, whose width is
        set by ``restore_center``, is reset to its original values.

        Args:
            tensor: Single-channel volume. ``tensor.numpy()`` is called, so it
                presumably must be a CPU tensor that does not require grad —
                TODO confirm.
            num_ghosts: Step between modified planes; ``0`` disables the
                artifact.
            axis: Index of the dimension along which planes are modified.
            intensity: Artifact strength; ``0`` disables the artifact.
            restore_center: Fraction in ``[0, 1]`` of the planes along
                ``axis``, centered at ``size // 2``, that are restored. At
                least one (the central) plane is always restored.

        Returns:
            ``tensor`` itself if ``num_ghosts`` or ``intensity`` is zero;
            otherwise a new tensor holding the real part of the inverse
            Fourier transform of the modified spectrum.

        Raises:
            ValueError: If ``axis`` is not a valid dimension index.
        """
        if not num_ghosts or not intensity:
            return tensor

        array = tensor.numpy()
        # String labels or out-of-range values would otherwise silently skip
        # every branch and fail later with an obscure error
        if axis not in range(array.ndim):
            message = (
                f'Axis must be an integer in [0, {array.ndim - 1}],'
                f' not "{axis}"'
            )
            raise ValueError(message)
        spectrum = self.fourier_transform(array)

        # Every num_ghosts-th plane along the axis
        ghost_index = [slice(None)] * spectrum.ndim
        ghost_index[axis] = slice(None, None, num_ghosts)
        ghost_index = tuple(ghost_index)

        # Band of central planes to keep intact. The center is size // 2, as
        # in the original code; presumably the zero-frequency position of the
        # shifted spectrum — TODO confirm against fourier_transform
        size = spectrum.shape[axis]
        num_restored = max(1, int(round(restore_center * size)))
        start = max(size // 2 - num_restored // 2, 0)
        center_index = [slice(None)] * spectrum.ndim
        center_index[axis] = slice(start, start + num_restored)
        center_index = tuple(center_index)
        restored = spectrum[center_index].copy()

        # Multiply by 0 if intensity is 1
        spectrum[ghost_index] *= 1 - intensity

        # Restore the center of k-space to avoid extreme artifacts
        spectrum[center_index] = restored

        array_ghosts = self.inv_fourier_transform(spectrum)
        array_ghosts = np.real(array_ghosts)
        return torch.from_numpy(array_ghosts)
|
213
|
|
|
|
|
214
|
|
|
|
|
215
|
|
|
def _parse_restore(restore): |
|
216
|
|
|
if not isinstance(restore, float): |
|
217
|
|
|
raise TypeError(f'Restore must be a float, not {restore}') |
|
218
|
|
|
if not 0 <= restore <= 1: |
|
219
|
|
|
message = ( |
|
220
|
|
|
f'Restore must be a number between 0 and 1, not {restore}') |
|
221
|
|
|
raise ValueError(message) |
|
222
|
|
|
return restore |
|
223
|
|
|
|