Passed
Pull Request — master (#219)
by Fernando
01:51
created

RandomGhosting.add_artifact()   B

Complexity

Conditions 7

Size

Total Lines 39
Code Lines 30

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 7
eloc 30
nop 6
dl 0
loc 39
rs 7.76
c 0
b 0
f 0
1
from typing import Tuple, Optional, Union
2
import torch
3
import numpy as np
4
import SimpleITK as sitk
5
from ....torchio import DATA, AFFINE
6
from ....data.subject import Subject
7
from .. import RandomTransform
8
9
10
class RandomGhosting(RandomTransform):
    r"""Add random MRI ghosting artifact.

    The artifact is produced in :math:`k`-space: every :math:`n`-th plane
    along the chosen axis is attenuated, then a band of central planes is
    put back to its original value before transforming back.

    Args:
        num_ghosts: Number of 'ghosts' :math:`n` in the image.
            If :py:attr:`num_ghosts` is a tuple :math:`(a, b)`, then
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
            If :math:`n = 0`, the image is left unchanged.
        axes: Axis along which the ghosts will be created. If
            :py:attr:`axes` is a tuple, the axis will be randomly chosen
            from the passed values. Axis ``0`` is never used for 2D images
            (images whose first spatial dimension has size 1).
        intensity: Non-negative number representing the artifact strength
            :math:`s`. The selected :math:`k`-space planes are multiplied
            by :math:`1 - s`. If ``0``, the ghosts will not be visible.
            If a tuple :math:`(a, b)` is provided, then
            :math:`s \sim \mathcal{U}(a, b)`.
        restore: Number between ``0`` and ``1`` indicating how much of the
            :math:`k`-space center should be restored after removing the
            planes that generate the artifact. It is the fraction of planes
            along the ghosting axis, centered at index ``size // 2``, that
            are restored. At least the central plane is always restored.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.

    .. note:: The execution time of this transform does not depend on the
        number of ghosts.
    """
    def __init__(
            self,
            num_ghosts: Union[int, Tuple[int, int]] = (4, 10),
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            intensity: Union[float, Tuple[float, float]] = (0.5, 1),
            restore: float = 0.02,
            p: float = 1,
            seed: Optional[int] = None,
            ):
        super().__init__(p=p, seed=seed)
        if not isinstance(axes, tuple):
            try:
                axes = tuple(axes)
            except TypeError:
                # A single (non-iterable) axis was given
                axes = (axes,)
        if not axes:
            raise ValueError('At least one axis must be specified')
        for axis in axes:
            if axis not in (0, 1, 2):
                raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"')
        self.axes = axes
        self.num_ghosts_range = self.parse_num_ghosts(num_ghosts)
        self.intensity_range = self.parse_intensity(intensity)
        if not 0 <= restore < 1:
            message = (
                f'Restore must be a number between 0 and 1, not {restore}')
            raise ValueError(message)
        self.restore = restore

    @staticmethod
    def parse_num_ghosts(num_ghosts):
        """Return ``num_ghosts`` as a ``(min, max)`` pair of natural numbers.

        Raises:
            ValueError: If any value is not a non-negative ``int``.
        """
        try:
            iter(num_ghosts)
        except TypeError:
            num_ghosts = num_ghosts, num_ghosts
        for n in num_ghosts:
            if not isinstance(n, int) or n < 0:
                message = (
                    f'Number of ghosts must be a natural number, not {n}')
                raise ValueError(message)
        return num_ghosts

    @staticmethod
    def parse_intensity(intensity):
        """Return ``intensity`` as a ``(min, max)`` pair of non-negative values.

        Raises:
            ValueError: If any value is negative.
        """
        try:
            iter(intensity)
        except TypeError:
            intensity = intensity, intensity
        for n in intensity:
            if n < 0:
                message = (
                    f'Intensity must be a positive number, not {n}')
                raise ValueError(message)
        return intensity

    def apply_transform(self, sample: Subject) -> Subject:
        """Add a ghosting artifact to each image of ``sample`` in place.

        Random parameters are drawn independently for every image and
        recorded with ``sample.add_transform``.

        Raises:
            ValueError: If an image is 2D and axis ``0`` is the only
                allowed axis.
        """
        random_parameters_images_dict = {}
        for image_name, image_dict in sample.get_images_dict().items():
            data = image_dict[DATA]
            # A size-1 first spatial dimension marks a 2D image; ghosting
            # along that axis is not meaningful, so axis 0 is excluded.
            is_2d = data.shape[-3] == 1
            axes = [a for a in self.axes if a != 0] if is_2d else self.axes
            if not axes:
                message = (
                    f'Image "{image_name}" is 2D, so ghosts cannot be added'
                    ' along axis 0; use axis 1 or 2 instead')
                raise ValueError(message)
            params = self.get_params(
                self.num_ghosts_range,
                axes,
                self.intensity_range,
            )
            num_ghosts_param, axis_param, intensity_param = params
            random_parameters_dict = {
                'axis': axis_param,
                'num_ghosts': num_ghosts_param,
                'intensity': intensity_param,
            }
            random_parameters_images_dict[image_name] = random_parameters_dict
            # Only the first channel is processed
            image_dict[DATA][0] = self.add_artifact(
                data[0],
                num_ghosts_param,
                axis_param,
                intensity_param,
                self.restore,
            )
        sample.add_transform(self, random_parameters_images_dict)
        return sample

    @staticmethod
    def get_params(
            num_ghosts_range: Tuple[int, int],
            axes: Tuple[int, ...],
            intensity_range: Tuple[float, float],
            ) -> Tuple:
        """Sample ``(num_ghosts, axis, intensity)`` using torch's RNG."""
        ng_min, ng_max = num_ghosts_range
        # Upper bound of randint is exclusive, hence the + 1
        num_ghosts = torch.randint(ng_min, ng_max + 1, (1,)).item()
        axis = axes[torch.randint(0, len(axes), (1,)).item()]
        intensity = torch.FloatTensor(1).uniform_(*intensity_range).item()
        return num_ghosts, axis, intensity

    def add_artifact(
            self,
            tensor: torch.Tensor,
            num_ghosts: int,
            axis: int,
            intensity: float,
            restore_center: float,
            ) -> torch.Tensor:
        """Return a copy of a 3D ``tensor`` with a ghosting artifact.

        Args:
            tensor: 3D tensor (one channel). Must be convertible with
                ``.numpy()``.
            num_ghosts: Every ``num_ghosts``-th plane along ``axis`` is
                attenuated. ``0`` means no ghosts; the input is returned.
            axis: Spatial axis, one of ``0``, ``1`` or ``2``.
            intensity: The attenuated planes are multiplied by
                ``1 - intensity``.
            restore_center: Fraction of planes along ``axis``, centered at
                index ``size // 2``, copied back from the original spectrum
                after attenuation. At least one plane is restored.

        Returns:
            Real part of the inverse transform of the modified spectrum.

        Raises:
            ValueError: If ``axis`` is not ``0``, ``1`` or ``2``.
        """
        if axis not in (0, 1, 2):
            raise ValueError(f'Axis must be in (0, 1, 2), not "{axis}"')
        if num_ghosts == 0:
            # A slice step of 0 is invalid, and no ghosts means no change
            return tensor

        array = tensor.numpy()
        spectrum = self.fourier_transform(array)
        size = array.shape[axis]

        # Index selecting every num_ghosts-th plane along the axis: the part
        # of the spectrum that will be modified.
        ghost_index = [slice(None)] * array.ndim
        ghost_index[axis] = slice(None, None, num_ghosts)
        ghost_index = tuple(ghost_index)

        # Index selecting the central band that will be restored. A band of
        # one plane reproduces restoring only the plane at size // 2.
        num_restore = max(1, int(np.round(restore_center * size)))
        start = max(0, size // 2 - num_restore // 2)
        stop = min(size, start + num_restore)
        restore_index = [slice(None)] * array.ndim
        restore_index[axis] = slice(start, stop)
        restore_index = tuple(restore_index)
        restore = spectrum[restore_index].copy()

        # Multiply by 0 if intensity is 1
        spectrum[ghost_index] *= 1 - intensity

        # Restore the center of k-space to avoid extreme artifacts
        spectrum[restore_index] = restore

        array_ghosts = self.inv_fourier_transform(spectrum)
        array_ghosts = np.real(array_ghosts)
        return torch.from_numpy(array_ghosts)
166