Passed
Pull Request — master (#625)
by Fernando
01:52
created

torchio.transforms.preprocessing.intensity.slab_projection   A

Complexity

Total Complexity 21

Size/Duplication

Total Lines 175
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 21
eloc 107
dl 0
loc 175
rs 10
c 0
b 0
f 0

7 Methods

Rating   Name   Duplication   Size   Complexity  
A SlabProjection.validate_percentile() 0 14 4
A SlabProjection.__init__() 0 22 1
A SlabProjection.get_projection_function() 0 19 2
A SlabProjection.get_num_slabs() 0 11 3
A SlabProjection.apply_projection() 0 7 3
A SlabProjection.apply_transform() 0 6 2
B SlabProjection.projection() 0 35 6
1
from numbers import Real
2
from typing import Union, Optional, Callable
3
4
import torch
5
from torchio.data.image import ScalarImage
6
from ....data.subject import Subject
7
from ...intensity_transform import IntensityTransform
8
9
10
class SlabProjection(IntensityTransform):
    """Project intensities along a given axis, possibly with sliding slabs.

    Args:
        axis: Index for the spatial dimension to project across.
            See :class:`~torchio.RandomFlip` for information on the accepted
            options.
        slab_thickness: Thickness of slab projections (number of voxels along
            the ``axis`` dimension to project across). Must be positive.
            If ``None``, the projection will be done across the entire span of
            the ``axis`` dimension (i.e. ``axis`` dimension will be reduced to
            1). If it is larger than the span of the ``axis`` dimension of an
            image, the entire span is used for that image.
        stride: Number of voxels to stride along the ``axis`` dimension between
            slab projections. Must be positive. Default is 1.
        projection_type: Type of intensity projection. Possible inputs are
            ``'max'``, ``'min'``, ``'mean'``, ``'median'``, or
            ``'percentile'``. If ``'percentile'`` is used, the ``percentile``
            argument must also be supplied. For ``'median'``, when a slab has
            an even number of voxels the lower of the two middle values is
            used (behavior of :func:`torch.median`).
        percentile: Percentile in the range :math:`[0, 100]` to use for
            intensity projections. This argument is required if
            ``projection_type`` is ``'percentile'`` and is silently ignored
            otherwise.
        full_slabs_only: If ``True``, projections will be done only for slabs
            that are ``slab_thickness`` thick. If ``False``, some slabs may not
            be ``slab_thickness`` thick depending on the size of the image,
            slab thickness, and stride.

    Example:
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> axial_mip = tio.SlabProjection("S", slab_thickness=20)
        >>> ct_t = axial_mip(ct)
        >>> ct_t.plot()

    .. plot::

        import torchio as tio
        sub = tio.datasets.Slicer('CTChest')
        ct = sub.CT_chest
        axial_mip = tio.SlabProjection("S", slab_thickness=20)
        ct_mip = axial_mip(ct)
        sub.add_image(ct_mip, 'CT_MIP')
        sub = tio.Clamp(-1000, 1000)(sub)
        sub.plot()

    """
    def __init__(
            self,
            axis: Union[int, str],
            slab_thickness: Optional[int] = None,
            stride: int = 1,
            projection_type: str = 'max',
            percentile: Optional[float] = None,
            full_slabs_only: bool = True,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.args_names = (
            'axis', 'slab_thickness', 'stride',
            'projection_type', 'percentile', 'full_slabs_only'
            )
        # Reject values that would otherwise hang (stride 0 never advances the
        # slab window) or fail later on an empty reduction (thickness 0).
        if slab_thickness is not None and slab_thickness < 1:
            message = (
                '`slab_thickness` must be a positive integer or None,'
                f' but "{slab_thickness}" was passed'
            )
            raise ValueError(message)
        if stride < 1:
            message = (
                '`stride` must be a positive integer,'
                f' but "{stride}" was passed'
            )
            raise ValueError(message)
        self.axis = self.parse_axes(axis)[0]
        # Stored exactly as given: the per-image effective thickness is
        # computed at apply time and never written back, so the transform
        # stays stateless and its arguments stay reproducible.
        self.slab_thickness = slab_thickness
        self.stride = stride
        self.projection_fun = self.get_projection_function(projection_type)
        self.projection_type = projection_type
        # Keep the user-facing percentile (listed in args_names) unchanged so
        # that re-instantiating from the arguments is correct; keep the
        # fraction expected by torch.quantile in a separate attribute.
        self.percentile = percentile
        self._quantile = self.validate_percentile(percentile, projection_type)
        self.full_slabs_only = full_slabs_only

    @staticmethod
    def validate_percentile(
            percentile: Optional[float],
            projection_type: str,
            ) -> Optional[float]:
        """Validate ``percentile`` and convert it to a quantile fraction.

        Args:
            percentile: Percentile in the range :math:`[0, 100]`.
            projection_type: Requested projection type. ``percentile`` is only
                validated when this is ``'percentile'``.

        Returns:
            ``percentile / 100`` if ``projection_type`` is ``'percentile'``,
            otherwise ``percentile`` unchanged.

        Raises:
            TypeError: If a percentile projection is requested and
                ``percentile`` is not a real number.
            ValueError: If a percentile projection is requested and
                ``percentile`` is outside :math:`[0, 100]`.
        """
        if projection_type != 'percentile':
            return percentile
        message = (
            "For projection_type='percentile', `percentile` must be a scalar"
            f' value in the range [0, 100], but "{percentile}" was passed'
        )
        if not isinstance(percentile, Real):
            raise TypeError(message)
        if not 0 <= percentile <= 100:
            raise ValueError(message)
        return percentile / 100

    @staticmethod
    def get_projection_function(projection_type: str) -> Callable:
        """Return the ``torch`` reduction implementing ``projection_type``.

        Raises:
            ValueError: If ``projection_type`` is not supported.
        """
        arg_to_function = {
            'max': 'amax',
            'min': 'amin',
            'mean': 'mean',
            'median': 'median',
            'percentile': 'quantile',
        }
        try:
            function_name = arg_to_function[projection_type]
        except KeyError as error:
            message = (
                f'The projection type must be in {tuple(arg_to_function)},'
                f' but "{projection_type}" was passed'
            )
            raise ValueError(message) from error
        return getattr(torch, function_name)

    def _get_slab_thickness(self, axis_span: int) -> int:
        """Return the effective slab thickness for an axis of ``axis_span``."""
        if self.slab_thickness is None:
            return axis_span
        return min(self.slab_thickness, axis_span)

    def get_num_slabs(
            self,
            axis_span: int,
            slab_thickness: Optional[int] = None,
            ) -> int:
        """Return how many slabs fit along an axis of length ``axis_span``.

        Args:
            axis_span: Number of voxels along the projection axis.
            slab_thickness: Thickness to use. If ``None``, the effective
                thickness derived from ``self.slab_thickness`` is used.

        Returns:
            With ``full_slabs_only``, the number of windows of
            ``slab_thickness`` voxels whose start advances by ``stride`` and
            that lie entirely within the axis. Otherwise, one slab per stride
            step, i.e. ``ceil(axis_span / stride)``.
        """
        if slab_thickness is None:
            slab_thickness = self._get_slab_thickness(axis_span)
        if self.full_slabs_only:
            if slab_thickness > axis_span:
                return 0
            return (axis_span - slab_thickness) // self.stride + 1
        return -(-axis_span // self.stride)  # integer ceil division

    def apply_transform(self, subject: Subject) -> Subject:
        """Project every selected image of ``subject`` in place."""
        spatial_axis_index = self.ensure_axes_indices(subject, [self.axis])[0]
        tensor_axis_index = spatial_axis_index + 1  # channels is 0
        for image in self.get_images(subject):
            self.apply_projection(image, tensor_axis_index)
        return subject

    def apply_projection(self, image: ScalarImage, axis_index: int) -> None:
        """Replace the data of ``image`` with its slab projection.

        The effective slab thickness is computed per image and is not stored
        on the transform, so images of different sizes are handled
        independently.
        """
        axis_span = image.shape[axis_index]
        slab_thickness = self._get_slab_thickness(axis_span)
        projected = self.projection(
            image.data,
            axis_index,
            axis_span,
            slab_thickness=slab_thickness,
        )
        image.set_data(projected)

    def _project_slab(
            self,
            slab: torch.Tensor,
            axis_index: int,
            ) -> torch.Tensor:
        """Reduce one slab along ``axis_index``, keeping that dimension."""
        if self.projection_type == 'median':
            # torch.median with ``dim`` returns a (values, indices) pair
            projected, _ = self.projection_fun(
                slab, dim=axis_index, keepdim=True)
        elif self.projection_type == 'percentile':
            projected = self.projection_fun(
                slab, q=self._quantile, dim=axis_index, keepdim=True)
        else:
            projected = self.projection_fun(
                slab, dim=axis_index, keepdim=True)
        return projected

    def projection(
            self,
            tensor: torch.Tensor,
            axis_index: int,
            axis_span: int,
            slab_thickness: Optional[int] = None,
            ) -> torch.Tensor:
        """Project ``tensor`` slab by slab along ``axis_index``.

        Args:
            tensor: Data to project.
            axis_index: Tensor dimension to project across.
            axis_span: Size of ``tensor`` along ``axis_index``.
            slab_thickness: Thickness to use. If ``None``, the effective
                thickness derived from ``self.slab_thickness`` is used.

        Returns:
            Tensor whose ``axis_index`` dimension has one entry per slab.
            Slab ``i`` covers indices
            ``[i * stride, min(i * stride + slab_thickness, axis_span))``.
        """
        if slab_thickness is None:
            slab_thickness = self._get_slab_thickness(axis_span)
        # torch.mean and torch.quantile need floating-point input
        if self.projection_type in ('mean', 'percentile'):
            tensor = tensor.to(torch.float)

        num_slabs = self.get_num_slabs(axis_span, slab_thickness)
        slabs = []
        for slab_index in range(num_slabs):
            start = slab_index * self.stride
            # Trailing slabs are truncated at the end of the axis (only
            # reachable when full_slabs_only is False)
            length = min(slab_thickness, axis_span - start)
            # narrow returns a view on the tensor's own device (no index copy)
            slab = tensor.narrow(axis_index, start, length)
            slabs.append(self._project_slab(slab, axis_index))
        return torch.cat(slabs, dim=axis_index)
175