Passed
Pull Request — master (#257)
by Fernando
01:05
created

RandomGhosting.__init__()   B

Complexity

Conditions 6

Size

Total Lines 25
Code Lines 24

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 24
nop 8
dl 0
loc 25
rs 8.3706
c 0
b 0
f 0

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists:

1
from typing import Tuple, Optional, Union, List
2
import torch
3
import numpy as np
4
from ....torchio import DATA
5
from ....data.subject import Subject
6
from .. import RandomTransform
7
8
9
class RandomGhosting(RandomTransform):
    r"""Add random MRI ghosting artifact.

    Discrete "ghost" artifacts may occur along the phase-encode direction
    whenever the position or signal intensity of imaged structures within the
    field-of-view vary or move in a regular (periodic) fashion. Pulsatile flow
    of blood or CSF, cardiac motion, and respiratory motion are the most
    important patient-related causes of ghost artifacts in clinical MR imaging
    (from `mriquestions.com <http://mriquestions.com/why-discrete-ghosts.html>`_).

    Args:
        num_ghosts: Number of 'ghosts' :math:`n` in the image.
            If :py:attr:`num_ghosts` is a tuple :math:`(a, b)`, then
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
            If only one value :math:`d` is provided,
            :math:`n \sim \mathcal{U}(0, d) \cap \mathbb{N}`.
        axes: Axis along which the ghosts will be created. If
            :py:attr:`axes` is a tuple, the axis will be randomly chosen
            from the passed values. Anatomical labels may also be used (see
            :py:class:`~torchio.transforms.augmentation.RandomFlip`).
        intensity: Positive number representing the artifact strength
            :math:`s` with respect to the maximum of the :math:`k`-space.
            If ``0``, the ghosts will not be visible. If a tuple
            :math:`(a, b)` is provided then :math:`s \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`s \sim \mathcal{U}(0, d)`.
        restore: Number between ``0`` and ``1`` indicating how much of the
            :math:`k`-space center should be restored after removing the planes
            that generate the artifact. The central plane along the chosen
            axis is always restored, together with
            ``int(restore * size / 2)`` planes on each side of it, where
            ``size`` is the image size along that axis.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.
        keys: See :py:class:`~torchio.transforms.Transform`.

    .. note:: The execution time of this transform does not depend on the
        number of ghosts.

    .. warning:: Note that height and width of 2D images correspond to axes
        ``1`` and ``2`` respectively, as TorchIO images are generally considered
        to have 3 spatial dimensions.
    """
    def __init__(
            self,
            num_ghosts: Union[int, Tuple[int, int]] = (4, 10),
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            intensity: Union[float, Tuple[float, float]] = (0.5, 1),
            restore: float = 0.02,
            p: float = 1,
            seed: Optional[int] = None,
            keys: Optional[List[str]] = None,
            ):
        super().__init__(p=p, seed=seed, keys=keys)
        self.axes = self._parse_axes(axes)
        self.num_ghosts_range = self.parse_range(
            num_ghosts, 'num_ghosts', min_constraint=0, type_constraint=int)
        # Use the user-facing parameter name so error messages match the API
        self.intensity_range = self.parse_range(
            intensity, 'intensity', min_constraint=0)
        self.restore = self.parse_restore(restore)

    @staticmethod
    def _parse_axes(axes) -> tuple:
        """Normalize ``axes`` to a non-empty tuple and validate its elements.

        Raises:
            ValueError: If ``axes`` is empty or contains a non-string value
                that is not ``0``, ``1`` or ``2``.
        """
        if isinstance(axes, str):
            # A single label is a scalar; tuple('LR') would split it into
            # characters.
            axes = (axes,)
        elif not isinstance(axes, tuple):
            try:
                axes = tuple(axes)
            except TypeError:
                axes = (axes,)
        if not axes:
            raise ValueError('At least one axis must be given')
        for axis in axes:
            if not isinstance(axis, str) and axis not in (0, 1, 2):
                raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"')
        return axes

    @staticmethod
    def parse_restore(restore):
        """Validate the ``restore`` fraction and return it as a float.

        Raises:
            TypeError: If ``restore`` is not a real number (booleans are
                rejected).
            ValueError: If ``restore`` is outside ``[0, 1]``.
        """
        if isinstance(restore, bool) or not isinstance(restore, (int, float)):
            raise TypeError(f'Restore must be a float, not {restore}')
        if not 0 <= restore <= 1:
            message = (
                f'Restore must be a number between 0 and 1, not {restore}')
            raise ValueError(message)
        return float(restore)

    def apply_transform(self, sample: Subject) -> Subject:
        """Add ghosting independently to every channel of every image.

        The sampled parameters are recorded on ``sample`` under keys of the
        form ``'<image_name>_channel_<index>'``. The same ``sample`` is
        modified in place and returned.
        """
        random_parameters_images_dict = {}
        if any(isinstance(n, str) for n in self.axes):
            sample.check_consistent_orientation()
            # NOTE(review): string (anatomical) axes pass validation but are
            # never converted to integer indices before reaching
            # add_artifact, which therefore rejects them. Confirm the intended
            # name-to-index helper and convert here.
        for image_name, image in sample.get_images_dict().items():
            # 2D images use axes 1 and 2 only (see the class warning), so
            # axis 0 is not a candidate for them.
            if image.is_2d():
                axes = [a for a in self.axes if a != 0]
            else:
                axes = self.axes
            if not axes:
                message = (
                    f'Cannot add ghosts to 2D image "{image_name}": the only'
                    f' requested axes {self.axes} are not used by 2D images')
                raise ValueError(message)
            transformed_tensors = []
            for channel_idx, tensor in enumerate(image[DATA]):
                num_ghosts, axis, intensity = self.get_params(
                    self.num_ghosts_range,
                    axes,
                    self.intensity_range,
                )
                key = f'{image_name}_channel_{channel_idx}'
                random_parameters_images_dict[key] = {
                    'axis': axis,
                    'num_ghosts': num_ghosts,
                    'intensity': intensity,
                }
                transformed_tensors.append(self.add_artifact(
                    tensor,
                    num_ghosts,
                    axis,
                    intensity,
                    self.restore,
                ))
            image[DATA] = torch.stack(transformed_tensors)
        sample.add_transform(self, random_parameters_images_dict)
        return sample

    @staticmethod
    def get_params(
            num_ghosts_range: Tuple[int, int],
            axes: Tuple[int, ...],
            intensity_range: Tuple[float, float],
            ) -> Tuple:
        """Sample ``(num_ghosts, axis, intensity)``.

        ``num_ghosts`` is drawn uniformly from the closed integer range,
        ``axis`` uniformly from ``axes``, and ``intensity`` uniformly from
        ``intensity_range``.
        """
        ng_min, ng_max = num_ghosts_range
        num_ghosts = torch.randint(ng_min, ng_max + 1, (1,)).item()
        axis = axes[torch.randint(0, len(axes), (1,)).item()]
        intensity = torch.FloatTensor(1).uniform_(*intensity_range).item()
        return num_ghosts, axis, intensity

    def add_artifact(
            self,
            tensor: torch.Tensor,
            num_ghosts: int,
            axis: int,
            intensity: float,
            restore_center: float,
            ):
        """Return ``tensor`` with a ghosting artifact along ``axis``.

        Every ``num_ghosts``-th plane of the spectrum along ``axis`` is scaled
        by ``1 - intensity``. After that, a central slab of planes is restored
        to its original values: the centre plane plus
        ``int(restore_center * size / 2)`` planes on each side of it.

        Args:
            tensor: A single channel. The code indexes it with three axes.
            num_ghosts: Stride of the modified planes; ``0`` disables the
                artifact.
            axis: Spatial axis, ``0``, ``1`` or ``2``.
            intensity: Artifact strength; ``0`` disables the artifact.
            restore_center: Fraction in ``[0, 1]`` controlling the width of
                the restored central slab.

        Raises:
            ValueError: If ``axis`` is not ``0``, ``1`` or ``2``.
        """
        if not num_ghosts or not intensity:
            return tensor
        if axis not in (0, 1, 2):
            raise ValueError(
                f'Ghosting axis must be 0, 1 or 2, not "{axis}"')

        array = tensor.numpy()
        # NOTE(review): index size // 2 is treated as the k-space centre,
        # which presumes fourier_transform (defined in a parent class) returns
        # a centred, fftshift-ed spectrum. TODO confirm.
        spectrum = self.fourier_transform(array)

        # Build per-axis index tuples generically instead of branching per
        # axis. This also guarantees both indices are always defined.
        ghost_index = [slice(None)] * 3
        ghost_index[axis] = slice(None, None, num_ghosts)
        ghost_index = tuple(ghost_index)

        size = array.shape[axis]
        center = size // 2
        distance = int(restore_center * size / 2)
        restore_index = [slice(None)] * 3
        restore_index[axis] = slice(
            max(center - distance, 0), center + distance + 1)
        restore_index = tuple(restore_index)

        # Save the centre before modifying the planes so it can be put back.
        restore = spectrum[restore_index].copy()

        # Multiply by 0 if intensity is 1
        spectrum[ghost_index] *= 1 - intensity

        # Restore the center of k-space to avoid extreme artifacts
        spectrum[restore_index] = restore

        array_ghosts = self.inv_fourier_transform(spectrum)
        array_ghosts = np.real(array_ghosts)
        return torch.from_numpy(array_ghosts)
176