Passed
Push — master ( bf11d4...8cedd2 )
by Fernando
01:04
created

RandomGhosting.__init__()   C

Complexity

Conditions 10

Size

Total Lines 28
Code Lines 27

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 10
eloc 27
nop 6
dl 0
loc 28
rs 5.9999
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex methods like torchio.transforms.augmentation.intensity.random_ghosting.RandomGhosting.__init__() often do a lot of different things. To break such a method down, we need to identify cohesive groups of statements within it. A common approach is to look for statements that handle the same parameter, or for fields/variables that share the same prefixes or suffixes (here, the separate validation of `axes`, `num_ghosts` and `intensity`).

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
import warnings
2
from typing import Tuple, Optional, Union
3
import torch
4
import numpy as np
5
import SimpleITK as sitk
6
from ....torchio import DATA, AFFINE
7
from ....data.subject import Subject
8
from .. import RandomTransform
9
10
11
class RandomGhosting(RandomTransform):
    r"""Add random MRI ghosting artifact.

    Args:
        num_ghosts: Number of 'ghosts' :math:`n` in the image.
            If :py:attr:`num_ghosts` is a tuple :math:`(a, b)`, then
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
            Values must satisfy :math:`0 \leq a \leq b`. If :math:`n = 0`,
            no ghosting is applied.
        axes: Axis along which the ghosts will be created. If
            :py:attr:`axes` is a tuple, the axis will be randomly chosen
            from the passed values. Each value must be ``0``, ``1`` or ``2``.
        intensity: Number between 0 and 1 representing the artifact strength
            :math:`s`. If ``0``, the ghosts will not be visible. If a tuple
            :math:`(a, b)` is provided, then
            :math:`s \sim \mathcal{U}(a, b)`.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.

    .. note:: The execution time of this transform does not depend on the
        number of ghosts.
    """
    def __init__(
            self,
            num_ghosts: Union[int, Tuple[int, int]] = (4, 10),
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            intensity: Union[float, Tuple[float, float]] = (0.5, 1),
            p: float = 1,
            seed: Optional[int] = None,
            ):
        super().__init__(p=p, seed=seed)
        self.axes = self._parse_axes(axes)
        self.num_ghosts_range = self._parse_num_ghosts(num_ghosts)
        self.intensity_range = self.parse_range(intensity, 'intensity')
        self._check_intensity_range(self.intensity_range)

    @staticmethod
    def _parse_axes(axes) -> Tuple[int, ...]:
        """Normalize ``axes`` to a non-empty tuple of values in (0, 1, 2).

        Raises:
            ValueError: If no axis is given or any axis is not 0, 1 or 2.
        """
        if not isinstance(axes, tuple):
            try:
                axes = tuple(axes)
            except TypeError:
                # A scalar (e.g. a single int) is wrapped into a 1-tuple
                axes = (axes,)
        if not axes:
            raise ValueError('At least one axis must be specified')
        for axis in axes:
            if axis not in (0, 1, 2):
                raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"')
        return axes

    @staticmethod
    def _parse_num_ghosts(num_ghosts) -> Tuple[int, int]:
        """Convert ``num_ghosts`` into a ``(min, max)`` range.

        Raises:
            ValueError: If ``num_ghosts`` is not an integer or a pair, or if
                the range does not satisfy ``0 <= min <= max``.
        """
        if isinstance(num_ghosts, (int, np.integer)):
            num_ghosts_range = num_ghosts, num_ghosts
        elif isinstance(num_ghosts, (tuple, list)) and len(num_ghosts) == 2:
            num_ghosts_range = tuple(num_ghosts)
        else:
            # Previously this case left num_ghosts_range unset, which only
            # failed later with an AttributeError
            raise ValueError(
                'num_ghosts must be an int or a tuple of two ints,'
                f' not "{num_ghosts}"'
            )
        low, high = num_ghosts_range
        if not 0 <= low <= high:
            raise ValueError(
                'num_ghosts range must satisfy 0 <= min <= max,'
                f' not {num_ghosts_range}'
            )
        return num_ghosts_range

    @staticmethod
    def _check_intensity_range(intensity_range: Tuple[float, float]) -> None:
        """Raise ``ValueError`` if any bound of the range is outside [0, 1]."""
        for n in intensity_range:
            if not 0 <= n <= 1:
                message = (
                    f'Intensity must be a number between 0 and 1, not {n}')
                raise ValueError(message)

    def apply_transform(self, sample: Subject) -> Subject:
        """Add a randomly parameterized ghosting artifact to each image.

        For every image in ``sample.get_images_dict()``, new parameters are
        drawn with :py:meth:`get_params`. The drawn parameters are recorded
        with ``sample.add_transform``. The image data is replaced in place
        with the ghosted version.

        Returns:
            The same ``sample`` object, modified in place.

        Raises:
            RuntimeError: If no valid axis remains for a 2D image.
        """
        random_parameters_images_dict = {}
        for image_name, image_dict in sample.get_images_dict().items():
            data = image_dict[DATA]
            # A singleton first spatial dimension marks a 2D image, for which
            # axis 0 is not a valid artifact axis
            is_2d = data.shape[-3] == 1
            axes = [a for a in self.axes if a != 0] if is_2d else self.axes
            if not axes:
                message = (
                    f'Image "{image_name}" is 2D, so axis 0 cannot be used;'
                    f' the configured axes {self.axes} leave no valid axis'
                )
                raise RuntimeError(message)
            params = self.get_params(
                self.num_ghosts_range,
                axes,
                self.intensity_range,
            )
            num_ghosts_param, axis_param, intensity_param = params
            random_parameters_dict = {
                'axis': axis_param,
                'num_ghosts': num_ghosts_param,
                'intensity': intensity_param,
            }
            random_parameters_images_dict[image_name] = random_parameters_dict
            # -0.1 rather than 0 so that tiny negative values produced by
            # floating-point noise (e.g. -7.191084e-35) do not trigger a warning
            if (data[0] < -0.1).any():
                message = (
                    f'Image "{image_name}" from "{image_dict["stem"]}"'
                    ' has negative values.'
                    ' Results can be unexpected because the transformed sample'
                    ' is computed as the absolute values'
                    ' of an inverse Fourier transform'
                )
                warnings.warn(message)
            # NOTE(review): only the first channel (data[0]) is processed and
            # the result has a single channel; confirm multi-channel images
            # are not expected here
            image = self.nib_to_sitk(
                data[0],
                image_dict[AFFINE],
            )
            data = self.add_artifact(
                image,
                num_ghosts_param,
                axis_param,
                intensity_param,
            )
            # Add channels dimension
            data = data[np.newaxis, ...]
            image_dict[DATA] = torch.from_numpy(data)
        sample.add_transform(self, random_parameters_images_dict)
        return sample

    @staticmethod
    def get_params(
            num_ghosts_range: Tuple[int, int],
            axes: Tuple[int, ...],
            intensity_range: Tuple[float, float],
            ) -> Tuple:
        """Sample ``(num_ghosts, axis, intensity)`` for one image.

        ``num_ghosts`` is drawn uniformly from the inclusive integer range.
        ``axis`` is drawn uniformly from ``axes``. ``intensity`` is drawn
        uniformly from ``intensity_range``.
        """
        ng_min, ng_max = num_ghosts_range
        # randint's upper bound is exclusive, hence + 1
        num_ghosts = torch.randint(ng_min, ng_max + 1, (1,)).item()
        axis_index = torch.randint(0, len(axes), (1,)).item()
        axis = axes[axis_index]
        intensity = torch.FloatTensor(1).uniform_(*intensity_range).item()
        return num_ghosts, axis, intensity

    @staticmethod
    def get_axis_and_size(axis, array):
        """Map a user-facing axis to the axis used by :py:meth:`get_slice`.

        Returns:
            A tuple ``(axis, size)``. ``axis`` is the value to pass to
            :py:meth:`get_slice`. ``size`` is the number of slices to iterate
            over along ``array``.

        Raises:
            RuntimeError: If ``axis`` is not 0, 1 or 2.
        """
        # NOTE(review): axes 0 and 1 are swapped here; together with the
        # RAS comments in get_slice this presumably selects the intended
        # artifact direction - confirm against the orientation convention
        if axis == 1:
            axis = 0
            size = array.shape[0]
        elif axis == 0:
            axis = 1
            size = array.shape[1]
        elif axis == 2:  # we will also traverse in sagittal (if RAS)
            size = array.shape[0]
        else:
            raise RuntimeError(f'Axis "{axis}" is not valid')
        return axis, size

    @staticmethod
    def get_slice(axis, array, slice_idx):
        """Return a 2D view of ``array`` at ``slice_idx`` for the given axis.

        Each result is a view, so in-place changes to it write back into
        ``array``. For axis 2 the view is transposed, so the rows of the
        slice run along the last dimension of ``array``.

        Raises:
            RuntimeError: If ``axis`` is not 0, 1 or 2.
        """
        # Comments apply if RAS
        if axis == 0:  # sagittal (columns) - artifact AP
            image_slice = array[slice_idx, ...]
        elif axis == 1:  # coronal (columns) - artifact LR
            image_slice = array[:, slice_idx, :]
        elif axis == 2:  # sagittal (rows) - artifact IS
            image_slice = array[slice_idx, ...].T
        else:
            raise RuntimeError(f'Axis "{axis}" is not valid')
        return image_slice

    def add_artifact(
            self,
            image: sitk.Image,
            num_ghosts: int,
            axis: int,
            intensity: float,
            ):
        """Return the voxel array of ``image`` with ghosting added.

        For every 2D slice along the chosen axis, every ``num_ghosts``-th row
        of the slice's spectrum is scaled by ``1 - intensity``. Rows in the
        central 5% of the spectrum are skipped. The slice is then replaced
        by the inverse transform of the modified spectrum.

        Args:
            image: Image whose voxel array is ghosted.
            num_ghosts: Row stride of the attenuated spectrum rows. ``0``
                means no ghosting.
            axis: User-facing axis (0, 1 or 2) along which slices are taken.
            intensity: Artifact strength in [0, 1].

        Returns:
            The transposed array from ``sitk.GetArrayFromImage``, modified
            in place.
        """
        array = sitk.GetArrayFromImage(image).transpose()
        if not num_ghosts:
            # No rows would be attenuated; this also avoids a modulo by zero
            return array
        # Rows whose relative position lies within 2.5% of the middle of the
        # spectrum (i.e. the central 5%) are left untouched. Presumably
        # fourier_transform returns a centered (shifted) spectrum, so these
        # are the lowest frequencies - TODO confirm
        percentage_to_avoid = 0.05
        axis, size = self.get_axis_and_size(axis, array)
        for slice_idx in range(size):
            image_slice = self.get_slice(axis, array, slice_idx)
            spectrum = self.fourier_transform(image_slice)
            # Position must be relative to this slice's spectrum. Using
            # array.shape[0] was wrong whenever the slice's first dimension
            # differs from it (non-cubic volumes).
            num_rows = len(spectrum)
            for row_idx, row in enumerate(spectrum):
                if row_idx % num_ghosts:
                    continue
                progress = row_idx / num_rows
                if np.abs(progress - 0.5) < percentage_to_avoid / 2:
                    continue
                row *= 1 - intensity
            # Write the result back through the view so that array is updated
            image_slice *= 0
            image_slice += self.inv_fourier_transform(spectrum)
        return array
172