Passed
Push — master ( 7b848f...cac223 )
by Fernando
02:39
created

RandomGhosting.__init__()   B

Complexity

Conditions 7

Size

Total Lines 18
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 7
eloc 17
nop 5
dl 0
loc 18
rs 8
c 0
b 0
f 0
1
import warnings
2
from typing import Tuple, Optional, Union
3
import torch
4
import numpy as np
5
import SimpleITK as sitk
6
from ....torchio import DATA, AFFINE
7
from ....data.subject import Subject
8
from .. import RandomTransform
9
10
11
class RandomGhosting(RandomTransform):
    r"""Add random MRI ghosting artifact.

    Args:
        num_ghosts: Number of 'ghosts' :math:`n` in the image.
            If :py:attr:`num_ghosts` is a tuple :math:`(a, b)`, then
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
            Values must be positive integers and :math:`a \leq b`.
        axes: Axis along which the ghosts will be created. If
            :py:attr:`axes` is a tuple (or list), the axis will be randomly
            chosen from the passed values. Each value must be in
            :math:`\{0, 1, 2\}`.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.

    Raises:
        ValueError: If :py:attr:`num_ghosts` or :py:attr:`axes` is invalid.

    .. note:: The execution time of this transform does not depend on the
        number of ghosts.
    """
    def __init__(
            self,
            num_ghosts: Union[int, Tuple[int, int]] = (4, 10),
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            p: float = 1,
            seed: Optional[int] = None,
            ):
        super().__init__(p=p, seed=seed)
        self.axes = self._parse_axes(axes)
        self.num_ghosts_range = self._parse_num_ghosts(num_ghosts)

    @staticmethod
    def _parse_axes(axes) -> Tuple[int, ...]:
        """Return ``axes`` as a non-empty tuple of values in ``(0, 1, 2)``.

        Raises:
            ValueError: If no axis is given or any axis is out of range.
        """
        # Accept tuples and lists; a scalar becomes a 1-tuple
        axes = tuple(axes) if isinstance(axes, (tuple, list)) else (axes,)
        if not axes:
            # An empty tuple would only fail later, inside torch.randint
            raise ValueError('At least one axis must be given')
        for axis in axes:
            if axis not in (0, 1, 2):
                raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"')
        return axes

    @staticmethod
    def _parse_num_ghosts(num_ghosts) -> Tuple[int, int]:
        """Return ``num_ghosts`` as a validated ``(min, max)`` pair.

        Raises:
            ValueError: If ``num_ghosts`` is neither an integer nor a pair,
                or if the range is not ``1 <= min <= max``.
        """
        if isinstance(num_ghosts, (int, np.integer)):
            ng_min = ng_max = num_ghosts
        elif isinstance(num_ghosts, (tuple, list)) and len(num_ghosts) == 2:
            ng_min, ng_max = num_ghosts
        else:
            raise ValueError(
                'num_ghosts must be an int or a sequence of two ints,'
                f' not "{num_ghosts}"'
            )
        # The number of ghosts is used as a modulus in add_artifact, so it
        # must be at least 1; min > max would make torch.randint fail
        if ng_min < 1 or ng_min > ng_max:
            raise ValueError(
                'num_ghosts must satisfy 1 <= min <= max,'
                f' not "{num_ghosts}"'
            )
        return ng_min, ng_max

    def apply_transform(self, sample: Subject) -> Subject:
        """Add a ghosting artifact to every image in ``sample``.

        The sampled parameters of each image are recorded with
        ``sample.add_transform``. The modified sample is returned.

        Raises:
            ValueError: If an image is 2D and the only allowed axis is 0.
        """
        random_parameters_images_dict = {}
        for image_name, image_dict in sample.get_images_dict().items():
            data = image_dict[DATA]
            # A singleton dimension -3 is treated as a 2D image, for which
            # axis 0 is excluded. Presumably data is (channels, x, y, z) —
            # TODO confirm.
            is_2d = data.shape[-3] == 1
            if is_2d:
                axes = tuple(a for a in self.axes if a != 0)
            else:
                axes = self.axes
            if not axes:
                raise ValueError(
                    f'Image "{image_name}" is 2D, so ghosting cannot be'
                    f' applied along axis 0 only (axes: {self.axes})'
                )
            num_ghosts_param, axis_param = self.get_params(
                self.num_ghosts_range,
                axes,
            )
            random_parameters_images_dict[image_name] = {
                'axis': axis_param,
                'num_ghosts': num_ghosts_param,
            }
            if (data[0] < -0.1).any():
                # I use -0.1 instead of 0 because Python was warning me when
                # a value in a voxel was -7.191084e-35
                # There must be a better way of solving this
                message = (
                    f'Image "{image_name}" from "{image_dict["stem"]}"'
                    ' has negative values.'
                    ' Results can be unexpected because the transformed sample'
                    ' is computed as the absolute values'
                    ' of an inverse Fourier transform'
                )
                warnings.warn(message)
            image = self.nib_to_sitk(
                data[0],
                image_dict[AFFINE],
            )
            data = self.add_artifact(
                image,
                num_ghosts_param,
                axis_param,
            )
            # Add channels dimension
            data = data[np.newaxis, ...]
            image_dict[DATA] = torch.from_numpy(data)
        sample.add_transform(self, random_parameters_images_dict)
        return sample

    @staticmethod
    def get_params(
            num_ghosts_range: Tuple[int, int],
            axes: Tuple[int, ...],
            ) -> Tuple[int, int]:
        """Sample the number of ghosts and the ghosting axis.

        Args:
            num_ghosts_range: Inclusive ``(min, max)`` range for the number
                of ghosts.
            axes: Non-empty sequence of candidate axes.

        Returns:
            ``(num_ghosts, axis)``, both drawn uniformly with ``torch.randint``.
        """
        ng_min, ng_max = num_ghosts_range
        num_ghosts_param = torch.randint(ng_min, ng_max + 1, (1,)).item()
        # Convert to a Python int before indexing the sequence
        axis_index = torch.randint(0, len(axes), (1,)).item()
        axis_param = axes[axis_index]
        return num_ghosts_param, axis_param

    @staticmethod
    def get_axis_and_size(axis, array):
        """Map a user-facing axis to the slicing axis used by ``get_slice``.

        Returns:
            ``(slicing_axis, number_of_slices)``. The user axes 0 and 1 are
            swapped; axis 2 is sliced along the first array dimension.

        Raises:
            RuntimeError: If ``axis`` is not 0, 1 or 2.
        """
        if axis == 1:
            axis = 0
            size = array.shape[0]
        elif axis == 0:
            axis = 1
            size = array.shape[1]
        elif axis == 2:  # we will also traverse in sagittal (if RAS)
            size = array.shape[0]
        else:
            raise RuntimeError(f'Axis "{axis}" is not valid')
        return axis, size

    @staticmethod
    def get_slice(axis, array, slice_idx):
        """Return a 2D view of ``array`` at ``slice_idx`` along ``axis``.

        Only basic indexing and ``.T`` are used, so the result is a NumPy
        view. Writing into it in place modifies ``array``.

        Raises:
            RuntimeError: If ``axis`` is not 0, 1 or 2.
        """
        # Comments apply if RAS
        if axis == 0:  # sagittal (columns) - artifact AP
            image_slice = array[slice_idx, ...]
        elif axis == 1:  # coronal (columns) - artifact LR
            image_slice = array[:, slice_idx, :]
        elif axis == 2:  # sagittal (rows) - artifact IS
            image_slice = array[slice_idx, ...].T
        else:
            raise RuntimeError(f'Axis "{axis}" is not valid')
        return image_slice

    def add_artifact(
            self,
            image: sitk.Image,
            num_ghosts: int,
            axis: int,
            ):
        """Return the array of ``image`` with ghosting along ``axis``.

        Each 2D slice is Fourier-transformed. Every ``num_ghosts``-th row
        of its spectrum is zeroed, except rows near the centre. The slice
        is then overwritten with the inverse transform.

        Args:
            image: Input image.
            num_ghosts: Spacing between zeroed spectrum rows (>= 1).
            axis: User-facing axis in ``(0, 1, 2)``.
        """
        # GetArrayFromImage returns axes in (z, y, x) order; transpose
        # restores (x, y, z)
        array = sitk.GetArrayFromImage(image).transpose()
        # Leave the rows within 2.5% of the centre of the spectrum untouched
        # (presumably the low frequencies, if fourier_transform shifts the
        # zero frequency to the centre — TODO confirm). If the image is in
        # RAS orientation, this helps applying the ghosting in the desired
        # axis intuitively
        percentage_to_avoid = 0.05
        axis, size = self.get_axis_and_size(axis, array)
        for slice_idx in range(size):
            image_slice = self.get_slice(axis, array, slice_idx)
            spectrum = self.fourier_transform(image_slice)
            # Normalize by the spectrum's own row count. The slice's first
            # dimension is not always array.shape[0], so on non-cubic
            # volumes that would misplace the central window.
            num_rows = len(spectrum)
            for row_idx, row in enumerate(spectrum):
                if row_idx % num_ghosts:
                    continue
                progress = row_idx / num_rows
                if np.abs(progress - 0.5) < percentage_to_avoid / 2:
                    continue
                # In-place; assumes spectrum is a NumPy array whose rows are
                # writable views
                row *= 0
            # image_slice is a view into array, so this updates array
            image_slice *= 0
            image_slice += self.inv_fourier_transform(spectrum)
        return array
152