Passed
Pull Request — master (#625)
by Fernando
01:19
created

SlabProjection.validate_percentile()   A

Complexity

Conditions 4

Size

Total Lines 14
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 11
dl 0
loc 14
rs 9.85
c 0
b 0
f 0
cc 4
nop 2
1
from numbers import Real
2
from typing import Union, Optional, Callable
3
4
import torch
5
from torchio.data.image import ScalarImage
6
from ....data.subject import Subject
7
from ...intensity_transform import IntensityTransform
8
9
10
class SlabProjection(IntensityTransform):
    """Project intensities along a given axis, possibly with sliding slabs.

    Args:
        axis: Index for the spatial dimension to project across.
            See :class:`~.torchio.RandomFlip` for information on the accepted
            types.
        slab_thickness: Thickness of slab projections. In other words, the
            number of voxels in the ``axis`` dimension to project across.
            Must be a positive integer or ``None``.
            If ``None``, the projection will be done across the entire span of
            the ``axis`` dimension (i.e. ``axis`` dimension will be reduced to
            1). A thickness larger than the span of the ``axis`` dimension of
            an image is clipped to that span, for that image only.
        stride: Number of voxels to stride along the ``axis`` dimension between
            slab projections. Must be a positive integer. Default is 1.
        projection_type: Type of intensity projection. Possible inputs are
            ``'max'`` (the default), ``'min'``, ``'mean'``, ``'median'``, or
            ``'percentile'``. If ``'percentile'`` is used, the ``percentile``
            argument must also be supplied.
        percentile: Percentile to use for intensity projections, in the range
            [0, 100]. This argument is required if ``projection_type`` is
            ``'percentile'`` and is silently ignored otherwise.
        full_slabs_only: Boolean. Should projections be done only for slabs
            that are ``slab_thickness`` thick? Default is ``True``.
            If ``False``, some slabs may not be ``slab_thickness`` thick
            depending on the size of the image, slab thickness, and stride.

    Example:
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> axial_mip = tio.SlabProjection("S", slab_thickness=20)
        >>> ct_t = axial_mip(ct)
        >>> ct_t.plot()

    .. plot::

        import torchio as tio
        sub = tio.datasets.Slicer('CTChest')
        ct = sub.CT_chest
        axial_mip = tio.SlabProjection("S", slab_thickness=20)
        ct_mip = axial_mip(ct)
        sub.add_image(ct_mip, 'CT_MIP')
        sub = tio.Clamp(-1000, 1000)(sub)
        sub.plot()

    """
    def __init__(
            self,
            axis: Union[int, str],
            slab_thickness: Optional[int] = None,
            stride: int = 1,
            projection_type: str = 'max',
            percentile: Optional[float] = None,
            full_slabs_only: bool = True,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.args_names = (
            'axis', 'slab_thickness', 'stride',
            'projection_type', 'percentile', 'full_slabs_only'
            )
        self.axis = self.parse_axes(axis)[0]
        if slab_thickness is not None:
            self._validate_positive_int(slab_thickness, 'slab_thickness')
        self._validate_positive_int(stride, 'stride')
        # Kept exactly as requested: the effective thickness for each image is
        # resolved on the fly in apply_projection, so one instance behaves the
        # same on images of different sizes.
        self.slab_thickness = slab_thickness
        self.stride = stride
        self.projection_fun = self.get_projection_function(projection_type)
        self.projection_type = projection_type
        self.percentile = self.validate_percentile(percentile, projection_type)
        self.full_slabs_only = full_slabs_only

    @staticmethod
    def _validate_positive_int(value, name: str) -> None:
        """Raise if ``value`` is not a strictly positive integer.

        A non-positive stride would never advance along the axis, and a
        non-positive thickness would produce empty slabs that cannot be
        reduced.

        Raises:
            TypeError: If ``value`` is not an integer (``bool`` is rejected).
            ValueError: If ``value`` is smaller than 1.
        """
        from numbers import Integral
        if isinstance(value, bool) or not isinstance(value, Integral):
            message = f'`{name}` must be an integer, but "{value}" was passed'
            raise TypeError(message)
        if value < 1:
            message = f'`{name}` must be positive, but "{value}" was passed'
            raise ValueError(message)

    @staticmethod
    def validate_percentile(percentile, projection_type):
        """Check ``percentile`` and convert it to a quantile in [0, 1].

        Args:
            percentile: Percentile in the range [0, 100].
            projection_type: Projection type. Validation only happens when it
                is ``'percentile'``.

        Returns:
            ``percentile / 100`` if ``projection_type`` is ``'percentile'``,
            otherwise ``percentile`` unchanged.

        Raises:
            TypeError: If ``percentile`` is not a real number.
            ValueError: If ``percentile`` is outside [0, 100] (or NaN).
        """
        if projection_type != 'percentile':
            return percentile
        message = (
            "For projection_type='percentile', `percentile` must be a scalar"
            f' value in the range [0, 100], but "{percentile}" was passed'
        )
        if not isinstance(percentile, Real):
            raise TypeError(message)
        if not 0 <= percentile <= 100:
            raise ValueError(message)
        # torch.quantile expects q in [0, 1].
        return percentile / 100

    @staticmethod
    def get_projection_function(projection_type: str) -> Callable:
        """Return the ``torch`` reduction implementing ``projection_type``.

        Raises:
            ValueError: If ``projection_type`` is not a supported type.
        """
        arg_to_function = {
            'max': 'amax',
            'min': 'amin',
            'mean': 'mean',
            'median': 'median',
            'percentile': 'quantile',
        }
        try:
            function_name = arg_to_function[projection_type]
        except KeyError:
            message = (
                f'The projection type must be one of {list(arg_to_function)},'
                f' but "{projection_type}" was passed'
            )
            raise ValueError(message) from None
        return getattr(torch, function_name)

    def _resolve_slab_thickness(
            self,
            axis_span: int,
            slab_thickness: Optional[int] = None,
            ) -> int:
        """Return the slab thickness to use along an axis of ``axis_span``.

        ``slab_thickness`` overrides ``self.slab_thickness`` when given.
        ``None`` means the whole axis; larger values are clipped to
        ``axis_span``. ``self`` is never modified.
        """
        if slab_thickness is None:
            slab_thickness = self.slab_thickness
        if slab_thickness is None:
            return axis_span
        return min(slab_thickness, axis_span)

    def get_num_slabs(
            self,
            axis_span: int,
            slab_thickness: Optional[int] = None,
            ) -> int:
        """Return how many slabs fit along an axis of length ``axis_span``.

        Args:
            axis_span: Number of voxels along the projection axis.
            slab_thickness: Optional override of ``self.slab_thickness``
                (``None`` falls back to it; see ``_resolve_slab_thickness``).
        """
        thickness = self._resolve_slab_thickness(axis_span, slab_thickness)
        if self.full_slabs_only:
            # Slab k starts at k * stride and must fit entirely:
            # k * stride + thickness <= axis_span.
            return (axis_span - thickness) // self.stride + 1
        # Every start index in [0, axis_span) yields a (possibly truncated)
        # slab: ceil(axis_span / stride) via integer arithmetic.
        return -(-axis_span // self.stride)

    def apply_transform(self, subject: Subject) -> Subject:
        # NOTE(review): this index is used directly on ``image.shape`` and
        # ``image.data`` in apply_projection. If image tensors are
        # channel-first and ensure_axes_indices returns a spatial index in
        # [0, 2], this is off by one (axis 0 would hit the channel
        # dimension) -- confirm against ensure_axes_indices.
        axis_index = self.ensure_axes_indices(subject, [self.axis])[0]
        for image in self.get_images(subject):
            self.apply_projection(image, axis_index)
        return subject

    def apply_projection(self, image: ScalarImage, axis_index: int) -> None:
        """Replace the data of ``image`` with its slab projection."""
        axis_span = image.shape[axis_index]
        # Resolved per image instead of being stored on ``self``, so a ``None``
        # or oversized thickness is not frozen to the span of the first image.
        slab_thickness = self._resolve_slab_thickness(axis_span)
        projected = self.projection(
            image.data, axis_index, axis_span, slab_thickness)
        image.set_data(projected)

    def projection(
            self,
            tensor: torch.Tensor,
            axis_index: int,
            axis_span: int,
            slab_thickness: Optional[int] = None,
            ) -> torch.Tensor:
        """Project ``tensor`` slab by slab along ``axis_index``.

        Args:
            tensor: Tensor to project.
            axis_index: Dimension of ``tensor`` to project across.
            axis_span: Size of ``tensor`` along ``axis_index``.
            slab_thickness: Optional override of ``self.slab_thickness``.

        Returns:
            Concatenation along ``axis_index`` of one projection per slab.
        """
        thickness = self._resolve_slab_thickness(axis_span, slab_thickness)
        if self.projection_type in ('mean', 'percentile'):
            # torch.mean and torch.quantile require floating-point input.
            tensor = tensor.to(torch.float)

        num_slabs = self.get_num_slabs(axis_span, thickness)
        slabs = []
        for slab_index in range(num_slabs):
            start = slab_index * self.stride
            # Trailing slabs are truncated at the end of the axis; this only
            # happens when full_slabs_only is False.
            length = min(thickness, axis_span - start)
            # narrow returns a view, so no index tensor is needed and the
            # slab stays on the same device as ``tensor``.
            slab = tensor.narrow(axis_index, start, length)
            slabs.append(self._project_slab(slab, axis_index))

        return torch.cat(slabs, dim=axis_index)

    def _project_slab(
            self,
            slab: torch.Tensor,
            axis_index: int,
            ) -> torch.Tensor:
        """Reduce ``slab`` along ``axis_index``, keeping that dimension."""
        if self.projection_type == 'median':
            # With ``dim`` given, torch.median returns (values, indices).
            projected, _ = self.projection_fun(
                slab, dim=axis_index, keepdim=True)
        elif self.projection_type == 'percentile':
            projected = self.projection_fun(
                slab, q=self.percentile, dim=axis_index,
                keepdim=True)
        else:
            projected = self.projection_fun(
                slab, dim=axis_index, keepdim=True)
        return projected
174