Passed
Pull Request — master (#389)
by Fernando
01:21
created

torchio.transforms.augmentation.intensity.random_spike   A

Complexity

Total Complexity 14

Size/Duplication

Total Lines 152
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 87
dl 0
loc 152
rs 10
c 0
b 0
f 0
wmc 14

6 Methods

Rating   Name   Duplication   Size   Complexity  
A RandomSpike.get_params() 0 10 1
A Spike.__init__() 0 11 1
A Spike.apply_transform() 0 17 4
A RandomSpike.apply_transform() 0 12 2
A RandomSpike.__init__() 0 11 1
B Spike.add_artifact() 0 29 5
1
from collections import defaultdict
2
from typing import Tuple, Union, Dict
3
4
import torch
5
import numpy as np
6
7
from ....data.subject import Subject
8
from ... import IntensityTransform, FourierTransform
9
from .. import RandomTransform
10
11
12
class RandomSpike(RandomTransform, IntensityTransform, FourierTransform):
    r"""Add random MRI spike artifacts.

    Also known as `Herringbone artifact
    <https://radiopaedia.org/articles/herringbone-artifact?lang=gb>`_,
    crisscross artifact or corduroy artifact. Spikes in k-space produce
    stripes running in different directions in image space.

    Args:
        num_spikes: Number of spikes :math:`n` present in k-space.
            If a tuple :math:`(a, b)` is provided, then
            :math:`n \sim \mathcal{U}(a, b) \cap \mathbb{N}`.
            If only one value :math:`d` is provided,
            :math:`n \sim \mathcal{U}(0, d) \cap \mathbb{N}`.
            Larger values generate more distorted images.
        intensity: Ratio :math:`r` between the spike intensity and the maximum
            of the spectrum.
            If a tuple :math:`(a, b)` is provided, then
            :math:`r \sim \mathcal{U}(a, b)`.
            If only one value :math:`d` is provided,
            :math:`r \sim \mathcal{U}(-d, d)`.
            Larger values generate more distorted images.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments.

    .. note:: The Fourier transforms are computed once per channel,
        whatever the number of spikes.
    """
    def __init__(
            self,
            num_spikes: Union[int, Tuple[int, int]] = 1,
            intensity: Union[float, Tuple[float, float]] = (1, 3),
            **kwargs
            ):
        super().__init__(**kwargs)
        # The intensity range is parsed before the spike-count range.
        self.intensity_range = self._parse_range(intensity, 'intensity_range')
        self.num_spikes_range = self._parse_range(
            num_spikes,
            'num_spikes',
            min_constraint=0,
            type_constraint=int,
        )

    def apply_transform(self, subject: Subject) -> Subject:
        """Sample spike parameters per image and delegate to :class:`Spike`."""
        keys = 'spikes_positions', 'intensity'
        arguments = defaultdict(dict)
        for name in self.get_images_dict(subject):
            sampled = self.get_params(
                self.num_spikes_range,
                self.intensity_range,
            )
            # Route each sampled value to its per-image argument dictionary.
            for key, value in zip(keys, sampled):
                arguments[key][name] = value
        spike = Spike(**self.add_include_exclude(arguments))
        return spike(subject)

    def get_params(
            self,
            num_spikes_range: Tuple[int, int],
            intensity_range: Tuple[float, float],
            ) -> Tuple[np.ndarray, float]:
        """Draw random spike positions and an intensity ratio.

        Args:
            num_spikes_range: Inclusive bounds for the number of spikes.
            intensity_range: Bounds passed to ``sample_uniform``.

        Returns:
            A tuple ``(positions, ratio)``:

            - ``positions`` is an array of shape ``(n, 3)`` drawn with
              :func:`torch.rand` (values in :math:`[0, 1)`).
            - ``n`` is drawn uniformly from ``num_spikes_range``, with both
              bounds included.
            - ``ratio`` is the sampled intensity, extracted with ``.item()``.
        """
        lowest, highest = num_spikes_range
        # Keep the order of the random draws (count, intensity, positions)
        # so seeded runs are reproducible.
        count = torch.randint(lowest, highest + 1, (1,)).item()
        ratio = self.sample_uniform(*intensity_range).item()
        positions = torch.rand(count, 3).numpy()
        return positions, ratio
74
75
76
class Spike(IntensityTransform, FourierTransform):
    r"""Add MRI spike artifacts.

    Also known as `Herringbone artifact
    <https://radiopaedia.org/articles/herringbone-artifact?lang=gb>`_,
    crisscross artifact or corduroy artifact, it creates stripes in different
    directions in image space due to spikes in k-space.

    Args:
        spikes_positions: Array of shape :math:`(n, 3)` with the relative
            positions of the :math:`n` spikes in k-space, one row per spike.
            Each value is multiplied by the corresponding spectrum size and
            floored to obtain a voxel index, so values are expected to lie in
            :math:`[0, 1)`. Alternatively, a dictionary mapping image names
            to such arrays.
        intensity: Ratio :math:`r` between the spike intensity and the maximum
            of the spectrum. Alternatively, a dictionary mapping image names
            to ratios.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional keyword arguments.

    .. note:: The execution time of this transform does not depend on the
        number of spikes.
    """
    def __init__(
            self,
            spikes_positions: Union[np.ndarray, Dict[str, np.ndarray]],
            intensity: Union[float, Dict[str, float]],
            **kwargs
            ):
        super().__init__(**kwargs)
        self.spikes_positions = spikes_positions
        self.intensity = intensity
        self.args_names = 'spikes_positions', 'intensity'
        self.invert_transform = False

    def apply_transform(self, subject: Subject) -> Subject:
        """Add the spikes to every channel of every selected image."""
        spikes_positions = self.spikes_positions
        intensity = self.intensity
        for image_name, image in self.get_images_dict(subject).items():
            if self.arguments_are_dict():
                # Per-image parameters, e.g. as sampled by RandomSpike.
                spikes_positions = self.spikes_positions[image_name]
                intensity = self.intensity[image_name]
            # Every channel receives the same spikes, independently.
            image.data = torch.stack([
                self.add_artifact(channel, spikes_positions, intensity)
                for channel in image.data
            ])
        return subject

    def add_artifact(
            self,
            tensor: torch.Tensor,
            spikes_positions: np.ndarray,
            intensity_factor: float,
            ) -> torch.Tensor:
        """Add spikes to the k-space of a single channel.

        Args:
            tensor: One image channel, expected to be 3D so that each spike
                index unpacks into ``(i, j, k)``.
            spikes_positions: Relative spike positions, shape ``(n, 3)``.
            intensity_factor: Spike magnitude relative to the maximum of the
                original (unmodified) spectrum.

        Returns:
            The corrupted channel as a real float tensor. If there are no
            spikes or the intensity is zero, ``tensor`` is returned unchanged.
        """
        if intensity_factor == 0 or len(spikes_positions) == 0:
            return tensor
        spectrum = self.fourier_transform(tensor)
        shape = np.array(spectrum.shape)
        indices = np.floor(spikes_positions * shape).astype(int)
        # Compute the spike magnitude once, from the unmodified spectrum.
        # Doing it inside the loop made each spike depend on the spikes
        # already added, so magnitudes could compound. It also copied the
        # whole volume to the CPU once per spike.
        # As of torch 1.7, "max is not yet implemented for complex tensors",
        # hence the round trip through NumPy.
        artifact = spectrum.cpu().numpy().max() * intensity_factor
        for i, j, k in indices:
            if self.invert_transform:
                spectrum[i, j, k] -= artifact
            else:
                spectrum[i, j, k] += artifact
            # A pure cosine would need a second, mirrored spike at
            # 2 * (shape // 2) - index. A single spike better represents the
            # actual cause of the artifact in real scans, so only one side of
            # k-space is modified.
        result = self.inv_fourier_transform(spectrum).real.float()
        return result
152