Passed
Push — master ( 0a9301...f0d368 )
by Fernando
01:26
created

RandomGhosting.apply_transform()   B

Complexity

Conditions 5

Size

Total Lines 35
Code Lines 32

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 5
eloc 32
nop 2
dl 0
loc 35
rs 8.6453
c 0
b 0
f 0
1
from typing import Tuple, Optional, Union, List
2
import torch
3
import numpy as np
4
from ....torchio import DATA
5
from ....data.subject import Subject
6
from .. import RandomTransform
7
8
9
class RandomGhosting(RandomTransform):
    r"""Add random MRI ghosting artifact.

    Discrete "ghost" artifacts may occur along the phase-encode direction
    whenever the position or signal intensity of imaged structures within the
    field-of-view vary or move in a regular (periodic) fashion. Pulsatile flow
    of blood or CSF, cardiac motion, and respiratory motion are the most
    important patient-related causes of ghost artifacts in clinical MR imaging
    (from `mriquestions.com <http://mriquestions.com/why-discrete-ghosts.html>`_).

    Args:
        num_ghosts: Number of 'ghosts' :math:`n` in the image.
            If :py:attr:`num_ghosts` is a tuple :math:`(a, b)`, then
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
            If only one value :math:`d` is provided,
            :math:`n \sim \mathcal{U}(0, d) \cap \mathbb{N}`.
        axes: Axis along which the ghosts will be created. If
            :py:attr:`axes` is a tuple, the axis will be randomly chosen
            from the passed values. Anatomical labels are accepted (see
            :py:class:`~torchio.transforms.augmentation.RandomFlip`), but
            they are not yet resolved to axis indices by this transform, so
            applying it with a label currently raises a :py:exc:`ValueError`.
        intensity: Positive number representing the artifact strength
            :math:`s` with respect to the maximum of the :math:`k`-space.
            If ``0``, the ghosts will not be visible. If a tuple
            :math:`(a, b)` is provided then :math:`s \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`s \sim \mathcal{U}(0, d)`.
        restore: Number between ``0`` and ``1`` indicating the fraction of
            the planes along the ghosting axis, centered on the
            :math:`k`-space center, that are restored after removing the
            planes that generate the artifact. At least the central plane is
            always restored.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. note:: The execution time of this transform does not depend on the
        number of ghosts.

    .. warning:: Note that height and width of 2D images correspond to axes
        ``1`` and ``2`` respectively, as TorchIO images are generally considered
        to have 3 spatial dimensions.
    """
    def __init__(
            self,
            num_ghosts: Union[int, Tuple[int, int]] = (4, 10),
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            intensity: Union[float, Tuple[float, float]] = (0.5, 1),
            restore: float = 0.02,
            p: float = 1,
            seed: Optional[int] = None,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, seed=seed, keys=keys)
        if isinstance(axes, str):
            # A single anatomical label must not be split into characters
            axes = (axes,)
        elif not isinstance(axes, tuple):
            try:
                axes = tuple(axes)
            except TypeError:
                axes = (axes,)
        for axis in axes:
            if not isinstance(axis, str) and axis not in (0, 1, 2):
                raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"')
        self.axes = axes
        self.num_ghosts_range = self.parse_range(
            num_ghosts, 'num_ghosts', min_constraint=0, type_constraint=int)
        self.intensity_range = self.parse_range(
            intensity, 'intensity', min_constraint=0)
        self.restore = self.parse_restore(restore)

    @staticmethod
    def parse_restore(restore) -> float:
        """Validate the ``restore`` fraction and return it as a float.

        Raises:
            TypeError: if ``restore`` is not a real number (bools rejected).
            ValueError: if ``restore`` is outside ``[0, 1]``.
        """
        if isinstance(restore, bool) or not isinstance(restore, (int, float)):
            raise TypeError(f'Restore must be a number, not {restore}')
        if not 0 <= restore <= 1:
            message = (
                f'Restore must be a number between 0 and 1, not {restore}')
            raise ValueError(message)
        return float(restore)

    def apply_transform(self, sample: Subject) -> Subject:
        """Add ghosting to every channel of every image in ``sample``.

        Parameters are drawn independently for each channel and recorded
        under the key ``'{image_name}_channel_{channel_idx}'`` through
        ``sample.add_transform``. Each image's data is replaced by the stacked
        transformed channels, and ``sample`` is returned.

        Raises:
            ValueError: if no usable axis remains for a 2D image, or if an
                anatomical label is drawn (see :meth:`add_artifact`).
        """
        random_parameters_images_dict = {}
        if any(isinstance(n, str) for n in self.axes):
            # Anatomical labels only make sense if all images share an
            # orientation.
            # TODO(review): labels are never converted to indices here, so
            # add_artifact will reject them.
            sample.check_consistent_orientation()
        for image_name, image in sample.get_images_dict().items():
            # 2D images use axes 1 and 2 (see class warning), so axis 0 is
            # not a candidate for ghosting.
            if image.is_2d():
                axes = [a for a in self.axes if a != 0]
            else:
                axes = self.axes
            if not axes:
                message = (
                    f'No valid ghosting axis for 2D image "{image_name}"'
                    f' among {self.axes}')
                raise ValueError(message)
            transformed_tensors = []
            for channel_idx, tensor in enumerate(image[DATA]):
                num_ghosts_param, axis_param, intensity_param = self.get_params(
                    self.num_ghosts_range,
                    axes,
                    self.intensity_range,
                )
                key = f'{image_name}_channel_{channel_idx}'
                random_parameters_images_dict[key] = {
                    'axis': axis_param,
                    'num_ghosts': num_ghosts_param,
                    'intensity': intensity_param,
                }
                transformed_tensors.append(self.add_artifact(
                    tensor,
                    num_ghosts_param,
                    axis_param,
                    intensity_param,
                    self.restore,
                ))
            image[DATA] = torch.stack(transformed_tensors)
        sample.add_transform(self, random_parameters_images_dict)
        return sample

    @staticmethod
    def get_params(
            num_ghosts_range: Tuple[int, int],
            axes: Tuple[int, ...],
            intensity_range: Tuple[float, float],
            ) -> Tuple[int, Union[int, str], float]:
        """Draw ``(num_ghosts, axis, intensity)`` using torch's RNG.

        ``num_ghosts`` is uniform over the inclusive integer range, ``axis``
        is picked uniformly from ``axes``, and ``intensity`` is uniform over
        ``intensity_range``.
        """
        ng_min, ng_max = num_ghosts_range
        num_ghosts = torch.randint(ng_min, ng_max + 1, (1,)).item()
        axis = axes[torch.randint(0, len(axes), (1,)).item()]
        intensity = torch.FloatTensor(1).uniform_(*intensity_range).item()
        return num_ghosts, axis, intensity

    def add_artifact(
            self,
            tensor: torch.Tensor,
            num_ghosts: int,
            axis: int,
            intensity: float,
            restore_center: float,
            ):
        """Return ``tensor`` with ghosting added along ``axis``.

        Every ``num_ghosts``-th plane of the spectrum along ``axis`` is scaled
        by ``1 - intensity``. Then a slab of
        ``max(1, round(restore_center * size))`` planes centered on the middle
        index along ``axis`` is reset to its original values, to avoid
        extreme artifacts.

        Args:
            tensor: Single-channel data; indexed along ``axis``. It is
                presumably a 3D CPU tensor (``.numpy()`` is called on it).
            num_ghosts: Step between modified planes; ``0`` disables the
                artifact.
            axis: Integer index of the axis to corrupt.
            intensity: Artifact strength; ``0`` disables the artifact.
            restore_center: Fraction of planes around the center to restore.

        Returns:
            ``tensor`` unchanged if ``num_ghosts`` or ``intensity`` is falsy,
            otherwise a new tensor holding the real part of the inverse
            transform.

        Raises:
            ValueError: if ``axis`` is not a valid integer axis index (for
                example an unresolved anatomical label).
        """
        if not num_ghosts or not intensity:
            return tensor

        array = tensor.numpy()
        if axis not in range(array.ndim):
            message = (
                f'Ghosting axis must be one of {tuple(range(array.ndim))},'
                f' not "{axis}". Anatomical labels are not yet supported'
                ' by this transform')
            raise ValueError(message)
        spectrum = self.fourier_transform(array)

        # Slab of planes around the k-space center that will be restored.
        # Always at least the central plane (the previous behavior).
        size = array.shape[axis]
        num_restored = int(np.round(restore_center * size))
        num_restored = min(size, max(1, num_restored))
        first = size // 2 - num_restored // 2
        restore_slices = [slice(None)] * array.ndim
        restore_slices[axis] = slice(first, first + num_restored)
        restore_slices = tuple(restore_slices)
        restored_planes = spectrum[restore_slices].copy()

        # Every num_ghosts-th plane along the axis generates the ghosts.
        # This basic slice is a view, so the spectrum is modified in place.
        # Multiply by 0 if intensity is 1.
        ghost_slices = [slice(None)] * array.ndim
        ghost_slices[axis] = slice(None, None, num_ghosts)
        spectrum[tuple(ghost_slices)] *= 1 - intensity

        # Restore the center of k-space to avoid extreme artifacts
        spectrum[restore_slices] = restored_planes

        array_ghosts = np.real(self.inv_fourier_transform(spectrum))
        return torch.from_numpy(array_ghosts)
176