Passed
Push — master ( a7650a...b61726 )
by Fernando
01:21
created

RandomGhosting.get_slice()   A

Complexity

Conditions 4

Size

Total Lines 12
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 10
nop 3
dl 0
loc 12
rs 9.9
c 0
b 0
f 0
1
import warnings
2
from typing import Tuple, Optional, Union
3
import torch
0 ignored issues
show
introduced by
Unable to import 'torch'
Loading history...
4
import numpy as np
0 ignored issues
show
introduced by
Unable to import 'numpy'
Loading history...
5
import SimpleITK as sitk
0 ignored issues
show
introduced by
Unable to import 'SimpleITK'
Loading history...
6
from ....torchio import DATA, AFFINE
7
from ....data.subject import Subject
8
from .. import RandomTransform
9
10
11
class RandomGhosting(RandomTransform):
    r"""Add random MRI ghosting artifact.

    The artifact is simulated slice by slice. Each 2D slice is Fourier
    transformed, and every :math:`n`-th row of its spectrum is set to zero,
    except for a narrow band of central rows. The slice is then replaced by
    the inverse transform of the modified spectrum.

    Args:
        num_ghosts: Number of 'ghosts' :math:`n` in the image.
            If :py:attr:`num_ghosts` is a tuple :math:`(a, b)`, then
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
            Values must be positive and :math:`a \leq b`.
        axes: Axis along which the ghosts will be created. If
            :py:attr:`axes` is a tuple, the axis will be randomly chosen
            from the passed values. Each axis must be 0, 1 or 2.
        p: Probability that this transform will be applied.
        seed: See :py:class:`~torchio.transforms.augmentation.RandomTransform`.

    Raises:
        ValueError: If :py:attr:`num_ghosts` or :py:attr:`axes` is invalid.

    .. note:: The execution time of this transform does not depend on the
        number of ghosts.
    """
    def __init__(
            self,
            num_ghosts: Union[int, Tuple[int, int]] = (4, 10),
            axes: Union[int, Tuple[int, ...]] = (0, 1, 2),
            p: float = 1,
            seed: Optional[int] = None,
            ):
        super().__init__(p=p, seed=seed)
        self.axes = self._parse_axes(axes)
        self.num_ghosts_range = self._parse_num_ghosts(num_ghosts)

    @staticmethod
    def _parse_axes(axes) -> Tuple[int, ...]:
        """Normalize ``axes`` to a non-empty tuple of values in (0, 1, 2)."""
        if isinstance(axes, list):
            axes = tuple(axes)
        elif not isinstance(axes, tuple):
            axes = (axes,)
        if not axes:
            # An empty tuple would make ``torch.randint(0, 0, ...)`` fail
            # later, inside get_params.
            raise ValueError('At least one axis must be given')
        for axis in axes:
            if axis not in (0, 1, 2):
                raise ValueError(f'Axes must be in (0, 1, 2), not "{axes}"')
        return axes

    @staticmethod
    def _parse_num_ghosts(num_ghosts) -> Tuple[int, int]:
        """Normalize ``num_ghosts`` to a validated ``(min, max)`` range."""
        if isinstance(num_ghosts, (int, np.integer)):
            num_ghosts_range = int(num_ghosts), int(num_ghosts)
        elif isinstance(num_ghosts, (tuple, list)) and len(num_ghosts) == 2:
            num_ghosts_range = tuple(num_ghosts)
        else:
            raise ValueError(
                'num_ghosts must be an int or a tuple of two ints,'
                f' not "{num_ghosts}"'
            )
        ng_min, ng_max = num_ghosts_range
        # 0 would cause a ZeroDivisionError in add_artifact (row % 0), and
        # min > max would make torch.randint fail in get_params.
        if ng_min < 1 or ng_min > ng_max:
            raise ValueError(
                'num_ghosts values must be positive and min <= max,'
                f' not "{num_ghosts}"'
            )
        return num_ghosts_range

    def apply_transform(self, sample: Subject) -> Subject:
        """Apply ghosting to every image in ``sample``.

        The sampled parameters are recorded via ``sample.add_transform``.
        Only the first channel of each image is processed. The resulting
        tensor has a single channel.
        """
        random_parameters_images_dict = {}
        for image_name, image_dict in sample.get_images_dict().items():
            num_ghosts_param, axis_param = self.get_params(
                self.num_ghosts_range,
                self.axes,
            )
            random_parameters_images_dict[image_name] = {
                'axis': axis_param,
                'num_ghosts': num_ghosts_param,
            }
            # A small negative tolerance is used instead of 0 so that
            # numerical noise (e.g. -7.191084e-35) does not trigger the
            # warning.
            if (image_dict[DATA][0] < -0.1).any():
                message = (
                    f'Image "{image_name}" from "{image_dict["stem"]}"'
                    ' has negative values.'
                    ' Results can be unexpected because the transformed sample'
                    ' is computed as the absolute values'
                    ' of an inverse Fourier transform'
                )
                warnings.warn(message)
            image = self.nib_to_sitk(
                image_dict[DATA][0],
                image_dict[AFFINE],
            )
            array = self.add_artifact(image, num_ghosts_param, axis_param)
            # Add channels dimension
            image_dict[DATA] = torch.from_numpy(array[np.newaxis, ...])
        sample.add_transform(self, random_parameters_images_dict)
        return sample

    @staticmethod
    def get_params(
            num_ghosts_range: Tuple[int, int],
            axes: Tuple[int, ...],
            ) -> Tuple[int, int]:
        """Sample the number of ghosts (inclusive range) and one axis."""
        ng_min, ng_max = num_ghosts_range
        num_ghosts_param = torch.randint(ng_min, ng_max + 1, (1,)).item()
        axis_param = axes[torch.randint(0, len(axes), (1,)).item()]
        return num_ghosts_param, axis_param

    @staticmethod
    def get_axis_and_size(axis, array):
        """Map a ghosting axis to the slicing axis used by :py:meth:`get_slice`.

        Args:
            axis: Ghosting axis, 0, 1 or 2.
            array: 3D array whose 2D slices will be processed.

        Returns:
            Tuple ``(axis, size)``. ``axis`` is the value expected by
            :py:meth:`get_slice`, and ``size`` is the number of slices to
            iterate over.

        Raises:
            RuntimeError: If ``axis`` is not 0, 1 or 2.
        """
        if axis == 1:
            axis = 0
            size = array.shape[0]
        elif axis == 0:
            axis = 1
            size = array.shape[1]
        elif axis == 2:  # we will also traverse in sagittal (if RAS)
            size = array.shape[0]
        else:
            raise RuntimeError(f'Axis "{axis}" is not valid')
        return axis, size

    @staticmethod
    def get_slice(axis, array, slice_idx):
        """Return a writable 2D view of ``array`` for the given slicing axis.

        The rows of the returned view run along the ghosting axis. Writing
        into the view modifies ``array`` in place.

        Raises:
            RuntimeError: If ``axis`` is not 0, 1 or 2.
        """
        # Comments apply if RAS
        if axis == 0:  # sagittal (columns) - artifact AP
            image_slice = array[slice_idx, ...]
        elif axis == 1:  # coronal (columns) - artifact LR
            image_slice = array[:, slice_idx, :]
        elif axis == 2:  # sagittal (rows) - artifact IS
            image_slice = array[slice_idx, ...].T
        else:
            raise RuntimeError(f'Axis "{axis}" is not valid')
        return image_slice

    def add_artifact(
            self,
            image: sitk.Image,
            num_ghosts: int,
            axis: int,
            ):
        """Return the array of ``image`` with ghosting along ``axis``.

        Args:
            image: Input image.
            num_ghosts: Every ``num_ghosts``-th spectrum row is zeroed.
                Must be positive.
            axis: Ghosting axis, 0, 1 or 2.

        Returns:
            NumPy array in the axis order of
            ``sitk.GetArrayFromImage(image).transpose()``.
        """
        array = sitk.GetArrayFromImage(image).transpose()
        # Leave the central 5% of the spectrum rows of each slice untouched.
        # If fourier_transform returns a centred (fftshift-ed) spectrum, as
        # assumed here (TODO confirm), these rows hold the lowest
        # frequencies, and keeping them preserves the overall image content.
        percentage_to_avoid = 0.05
        axis, size = self.get_axis_and_size(axis, array)
        for slice_idx in range(size):
            image_slice = self.get_slice(axis, array, slice_idx)
            spectrum = self.fourier_transform(image_slice)
            # Progress must be measured against the rows of *this* spectrum.
            # Using array.shape[0] misplaces the protected band for axes 1
            # and 2 when the volume is not cubic.
            num_rows = spectrum.shape[0]
            row_indices = np.arange(num_rows)
            progress = row_indices / num_rows
            is_ghost_row = row_indices % num_ghosts == 0
            is_central_row = np.abs(progress - 0.5) < percentage_to_avoid / 2
            spectrum[is_ghost_row & ~is_central_row] = 0
            # Overwrite the slice in place; image_slice is a view of array.
            image_slice *= 0
            image_slice += self.inv_fourier_transform(spectrum)
        return array
149