Passed
Push — master ( c291a8...879ee9 )
by Fernando
59s
created

RandomGhosting.apply_transform()   A

Complexity

Conditions 3

Size

Total Lines 27
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 24
nop 2
dl 0
loc 27
rs 9.304
c 0
b 0
f 0
1
from typing import Tuple, Optional, Union
2
import torch
3
import numpy as np
4
from ....torchio import DATA
5
from ....data.subject import Subject
6
from .. import RandomTransform
7
8
9
class RandomGhosting(RandomTransform):
    r"""Add random MRI ghosting artifact.

    Args:
        num_ghosts: Number of 'ghosts' :math:`n` in the image.
            If :py:attr:`num_ghosts` is a tuple :math:`(a, b)`, then
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
            If only one value :math:`d` is provided,
            :math:`n \sim \mathcal{U}(0, d) \cap \mathbb{N}`.
        axes: Axis along which the ghosts will be created. If
            :py:attr:`axes` is a tuple, the axis will be randomly chosen
            from the passed values.
        intensity: Positive number representing the artifact strength
            :math:`s` with respect to the maximum of the :math:`k`-space.
            The selected :math:`k`-space planes are scaled by :math:`1 - s`.
            If ``0``, the ghosts will not be visible. If a tuple
            :math:`(a, b)` is provided then :math:`s \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`s \sim \mathcal{U}(0, d)`.
        restore: Number between ``0`` and ``1`` indicating how much of the
            :math:`k`-space center should be restored after removing the planes
            that generate the artifact. It is interpreted as a fraction of
            the spectrum size along the ghosting axis; at least the central
            plane is always restored.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.

    .. note:: The execution time of this transform does not depend on the
        number of ghosts.
    """
    def __init__(
            self,
            num_ghosts: Union[int, Tuple[int, int]] = (4, 10),
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            intensity: Union[float, Tuple[float, float]] = (0.5, 1),
            restore: float = 0.02,
            p: float = 1,
            seed: Optional[int] = None,
            ):
        super().__init__(p=p, seed=seed)
        self.axes = self._parse_axes(axes)
        self.num_ghosts_range = self.parse_range(
            num_ghosts, 'num_ghosts', min_constraint=0, type_constraint=int)
        # The user-facing parameter is called "intensity", so error
        # messages produced by parse_range should use that name
        self.intensity_range = self.parse_range(
            intensity, 'intensity', min_constraint=0)
        self.restore = self.parse_restore(restore)

    @staticmethod
    def _parse_axes(axes) -> Tuple[int, ...]:
        """Normalize ``axes`` to a tuple of ints, each in ``(0, 1, 2)``.

        A single axis (any non-iterable) is wrapped in a 1-tuple.

        Raises:
            ValueError: If any axis is not 0, 1 or 2.
        """
        if not isinstance(axes, tuple):
            try:
                axes = tuple(axes)
            except TypeError:
                axes = (axes,)
        for axis in axes:
            if axis not in (0, 1, 2):
                raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"')
        # Cast so the axes can be used directly as positional indices
        return tuple(int(axis) for axis in axes)

    @staticmethod
    def parse_restore(restore):
        """Validate the ``restore`` fraction and return it as a float.

        Raises:
            TypeError: If ``restore`` is not a real number (bools rejected).
            ValueError: If ``restore`` is outside ``[0, 1]``.
        """
        if isinstance(restore, bool) or not isinstance(restore, (int, float)):
            raise TypeError(f'Restore must be a float, not {restore}')
        if not 0 <= restore <= 1:
            message = (
                f'Restore must be a number between 0 and 1, not {restore}')
            raise ValueError(message)
        return float(restore)

    def apply_transform(self, sample: Subject) -> Subject:
        """Add a ghosting artifact to every image of ``sample`` in place.

        Each image gets its own random parameters. These are recorded
        through ``sample.add_transform`` under the image name.

        Raises:
            ValueError: If an image is treated as 2D and no axis other
                than 0 is allowed.
        """
        random_parameters_images_dict = {}
        for image_name, image_dict in sample.get_images_dict().items():
            data = image_dict[DATA]
            # A singleton size along the first spatial axis is treated as a
            # 2D image; ghosting along that axis is excluded
            is_2d = data.shape[-3] == 1
            axes = [a for a in self.axes if a != 0] if is_2d else self.axes
            if not axes:
                message = (
                    f'Image "{image_name}" is 2D, but the only ghosting'
                    f' axes allowed are {self.axes}'
                )
                raise ValueError(message)
            num_ghosts_param, axis_param, intensity_param = self.get_params(
                self.num_ghosts_range,
                axes,
                self.intensity_range,
            )
            random_parameters_images_dict[image_name] = {
                'axis': axis_param,
                'num_ghosts': num_ghosts_param,
                'intensity': intensity_param,
            }
            # Only the first channel is modified, as in the original code
            image_dict[DATA][0] = self.add_artifact(
                data[0],
                num_ghosts_param,
                axis_param,
                intensity_param,
                self.restore,
            )
        sample.add_transform(self, random_parameters_images_dict)
        return sample

    @staticmethod
    def get_params(
            num_ghosts_range: Tuple[int, int],
            axes: Tuple[int, ...],
            intensity_range: Tuple[float, float],
            ) -> Tuple:
        """Sample ``(num_ghosts, axis, intensity)``.

        ``num_ghosts`` is drawn uniformly from the inclusive integer range.
        ``axis`` is one element of ``axes``. ``intensity`` is drawn uniformly
        from ``intensity_range``.
        """
        ng_min, ng_max = num_ghosts_range
        num_ghosts = torch.randint(ng_min, ng_max + 1, (1,)).item()
        axis = axes[torch.randint(0, len(axes), (1,)).item()]
        intensity = torch.FloatTensor(1).uniform_(*intensity_range).item()
        return num_ghosts, axis, intensity

    def add_artifact(
            self,
            tensor: torch.Tensor,
            num_ghosts: int,
            axis: int,
            intensity: float,
            restore_center: float,
            ):
        """Return ``tensor`` with a ghosting artifact along ``axis``.

        Every ``num_ghosts``-th plane of the spectrum along ``axis`` is
        scaled by ``1 - intensity``. Afterwards, the central
        ``max(1, round(restore_center * size))`` planes along ``axis`` are
        put back to their original values.

        Assumes (as the original code did by restoring index ``shape // 2``)
        that ``self.fourier_transform`` returns a spectrum whose k-space
        center lies at ``shape // 2``. TODO confirm against its definition.
        """
        if num_ghosts == 0:
            return tensor

        array = tensor.numpy()
        spectrum = self.fourier_transform(array)

        size = spectrum.shape[axis]
        center = size // 2
        # restore_center is a fraction of the axis size; always keep at least
        # the central plane, which matches the previous single-plane behavior
        num_restore = max(1, int(np.round(restore_center * size)))
        first = max(0, center - num_restore // 2)
        last = min(size, first + num_restore)

        # Index tuples that select slabs along ``axis`` only; this replaces
        # per-axis if/elif chains, so every branch is well defined
        ghost_index = [slice(None)] * spectrum.ndim
        ghost_index[axis] = slice(None, None, num_ghosts)
        ghost_index = tuple(ghost_index)
        restore_index = [slice(None)] * spectrum.ndim
        restore_index[axis] = slice(first, last)
        restore_index = tuple(restore_index)

        restore = spectrum[restore_index].copy()

        # Multiply by 0 if intensity is 1
        spectrum[ghost_index] *= 1 - intensity

        # Restore the center of k-space to avoid extreme artifacts
        spectrum[restore_index] = restore

        array_ghosts = self.inv_fourier_transform(spectrum)
        array_ghosts = np.real(array_ghosts)
        return torch.from_numpy(array_ghosts)