Completed
Pull Request — master (#353)
by Fernando
118:39 queued 117:31
created

RandomGhosting.add_artifact()   C

Complexity

Conditions 9

Size

Total Lines 43
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 33
nop 6
dl 0
loc 43
rs 6.6666
c 0
b 0
f 0
1
from collections import defaultdict
2
from typing import Tuple, Optional, Union, Sequence, Dict
3
4
import torch
5
import numpy as np
6
7
from ....torchio import DATA
8
from ....data.subject import Subject
9
from ... import IntensityTransform, FourierTransform
10
from .. import RandomTransform
11
12
13
class RandomGhosting(RandomTransform, IntensityTransform):
14
    r"""Add random MRI ghosting artifact.
15
16
    Discrete "ghost" artifacts may occur along the phase-encode direction
17
    whenever the position or signal intensity of imaged structures within the
18
    field-of-view vary or move in a regular (periodic) fashion. Pulsatile flow
19
    of blood or CSF, cardiac motion, and respiratory motion are the most
20
    important patient-related causes of ghost artifacts in clinical MR imaging
21
    (from `mriquestions.com <http://mriquestions.com/why-discrete-ghosts.html>`_).
22
23
    Args:
24
        num_ghosts: Number of 'ghosts' :math:`n` in the image.
25
            If :attr:`num_ghosts` is a tuple :math:`(a, b)`, then
26
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
27
            If only one value :math:`d` is provided,
28
            :math:`n \sim \mathcal{U}(0, d) \cap \mathbb{N}`.
29
        axes: Axis along which the ghosts will be created. If
30
            :attr:`axes` is a tuple, the axis will be randomly chosen
31
            from the passed values. Anatomical labels may also be used (see
32
            :class:`~torchio.transforms.augmentation.RandomFlip`).
33
        intensity: Positive number representing the artifact strength
34
            :math:`s` with respect to the maximum of the :math:`k`-space.
35
            If ``0``, the ghosts will not be visible. If a tuple
36
            :math:`(a, b)` is provided then :math:`s \sim \mathcal{U}(a, b)`.
37
            If only one value :math:`d` is provided,
38
            :math:`s \sim \mathcal{U}(0, d)`.
39
        restore: Number between ``0`` and ``1`` indicating how much of the
40
            :math:`k`-space center should be restored after removing the planes
41
            that generate the artifact.
42
        p: Probability that this transform will be applied.
43
        keys: See :class:`~torchio.transforms.Transform`.
44
45
    .. note:: The execution time of this transform does not depend on the
46
        number of ghosts.
47
    """
48
    def __init__(
49
            self,
50
            num_ghosts: Union[int, Tuple[int, int]] = (4, 10),
51
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
52
            intensity: Union[float, Tuple[float, float]] = (0.5, 1),
53
            restore: float = 0.02,
54
            p: float = 1,
55
            keys: Optional[Sequence[str]] = None,
56
            ):
57
        super().__init__(p=p, keys=keys)
58
        if not isinstance(axes, tuple):
59
            try:
60
                axes = tuple(axes)
61
            except TypeError:
62
                axes = (axes,)
63
        for axis in axes:
64
            if not isinstance(axis, str) and axis not in (0, 1, 2):
65
                raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"')
66
        self.axes = axes
67
        self.num_ghosts_range = self.parse_range(
68
            num_ghosts, 'num_ghosts', min_constraint=0, type_constraint=int)
69
        self.intensity_range = self.parse_range(
70
            intensity, 'intensity_range', min_constraint=0)
71
        self.restore = _parse_restore(restore)
72
73
    def apply_transform(self, subject: Subject) -> Subject:
74
        arguments = defaultdict(dict)
75
        if any(isinstance(n, str) for n in self.axes):
76
            subject.check_consistent_orientation()
77
        for name, image in self.get_images_dict(subject).items():
78
            is_2d = image.is_2d()
79
            axes = [a for a in self.axes if a != 2] if is_2d else self.axes
80
            params = self.get_params(
81
                self.num_ghosts_range,
82
                axes,
83
                self.intensity_range,
84
            )
85
            num_ghosts_param, axis_param, intensity_param = params
86
            arguments['num_ghosts'][name] = num_ghosts_param
87
            arguments['axis'][name] = axis_param
88
            arguments['intensity'][name] = intensity_param
89
            arguments['restore'][name] = self.restore
90
        transform = Ghosting(**arguments)
91
        transformed = transform(subject)
92
        return transformed
93
94
    def get_params(
95
            self,
96
            num_ghosts_range: Tuple[int, int],
97
            axes: Tuple[int, ...],
98
            intensity_range: Tuple[float, float],
99
            ) -> Tuple:
100
        ng_min, ng_max = num_ghosts_range
101
        num_ghosts = torch.randint(ng_min, ng_max + 1, (1,)).item()
102
        axis = axes[torch.randint(0, len(axes), (1,))]
103
        intensity = self.sample_uniform(*intensity_range).item()
104
        return num_ghosts, axis, intensity
105
106
107
class Ghosting(IntensityTransform, FourierTransform):
108
    r"""Add MRI ghosting artifact.
109
110
    Discrete "ghost" artifacts may occur along the phase-encode direction
111
    whenever the position or signal intensity of imaged structures within the
112
    field-of-view vary or move in a regular (periodic) fashion. Pulsatile flow
113
    of blood or CSF, cardiac motion, and respiratory motion are the most
114
    important patient-related causes of ghost artifacts in clinical MR imaging
115
    (from `mriquestions.com <http://mriquestions.com/why-discrete-ghosts.html>`_).
116
117
    Args:
118
        num_ghosts: Number of 'ghosts' :math:`n` in the image.
119
        axes: Axis along which the ghosts will be created.
120
        intensity: Positive number representing the artifact strength
121
            :math:`s` with respect to the maximum of the :math:`k`-space.
122
            If ``0``, the ghosts will not be visible.
123
        restore: Number between ``0`` and ``1`` indicating how much of the
124
            :math:`k`-space center should be restored after removing the planes
125
            that generate the artifact.
126
        keys: See :class:`~torchio.transforms.Transform`.
127
128
    .. note:: The execution time of this transform does not depend on the
129
        number of ghosts.
130
    """
131
    def __init__(
132
            self,
133
            num_ghosts: Union[int, Dict[str, int]],
134
            axis: Union[int, Dict[str, int]],
135
            intensity: Union[float, Dict[str, float]],
136
            restore: Union[float, Dict[str, float]],
137
            keys: Optional[Sequence[str]] = None,
138
            ):
139
        super().__init__(keys=keys)
140
        self.axis = axis
141
        self.num_ghosts = num_ghosts
142
        self.intensity = intensity
143
        self.restore = restore
144
        self.args_names = 'num_ghosts', 'axis', 'intensity', 'restore'
145
146
    def apply_transform(self, subject: Subject) -> Subject:
147
        axis = self.axis
148
        num_ghosts = self.num_ghosts
149
        intensity = self.intensity
150
        restore = self.restore
151
        for name, image in self.get_images_dict(subject).items():
152
            if self.arguments_are_dict():
153
                axis = self.axis[name]
154
                num_ghosts = self.num_ghosts[name]
155
                intensity = self.intensity[name]
156
                restore = self.restore[name]
157
            transformed_tensors = []
158
            for tensor in image.data:
159
                transformed_tensor = self.add_artifact(
160
                    tensor,
161
                    num_ghosts,
162
                    axis,
163
                    intensity,
164
                    restore,
165
                )
166
                transformed_tensors.append(transformed_tensor)
167
            image[DATA] = torch.stack(transformed_tensors)
168
        return subject
169
170
    def add_artifact(
171
            self,
172
            tensor: torch.Tensor,
173
            num_ghosts: int,
174
            axis: int,
175
            intensity: float,
176
            restore_center: float,
177
            ):
178
        if not num_ghosts or not intensity:
179
            return tensor
180
181
        array = tensor.numpy()
182
        spectrum = self.fourier_transform(array)
183
184
        shape = np.array(array.shape)
185
        ri, rj, rk = np.round(restore_center * shape).astype(np.uint16)
186
        mi, mj, mk = np.array(array.shape) // 2
187
188
        # Variable "planes" is the part of the spectrum that will be modified
189
        if axis == 0:
190
            planes = spectrum[::num_ghosts, :, :]
191
            restore = spectrum[mi, :, :].copy()
192
        elif axis == 1:
193
            planes = spectrum[:, ::num_ghosts, :]
194
            restore = spectrum[:, mj, :].copy()
195
        elif axis == 2:
196
            planes = spectrum[:, :, ::num_ghosts]
197
            restore = spectrum[:, :, mk].copy()
198
199
        # Multiply by 0 if intensity is 1
200
        planes *= 1 - intensity
0 ignored issues
show
introduced by
The variable planes does not seem to be defined for all execution paths.
Loading history...
201
202
        # Restore the center of k-space to avoid extreme artifacts
203
        if axis == 0:
204
            spectrum[mi, :, :] = restore
0 ignored issues
show
introduced by
The variable restore does not seem to be defined for all execution paths.
Loading history...
205
        elif axis == 1:
206
            spectrum[:, mj, :] = restore
207
        elif axis == 2:
208
            spectrum[:, :, mk] = restore
209
210
        array_ghosts = self.inv_fourier_transform(spectrum)
211
        array_ghosts = np.real(array_ghosts)
212
        return torch.from_numpy(array_ghosts)
213
214
215
def _parse_restore(restore):
216
    if not isinstance(restore, float):
217
        raise TypeError(f'Restore must be a float, not {restore}')
218
    if not 0 <= restore <= 1:
219
        message = (
220
            f'Restore must be a number between 0 and 1, not {restore}')
221
        raise ValueError(message)
222
    return restore
223