Passed
Pull Request — master (#625)
by Fernando
03:33
created

torchio.transforms.preprocessing.intensity.slab_projection   A

Complexity

Total Complexity 21

Size/Duplication

Total Lines 169
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 21
eloc 102
dl 0
loc 169
rs 10
c 0
b 0
f 0

7 Methods

Rating   Name   Duplication   Size   Complexity  
A SlabProjection.validate_percentile() 0 14 4
A SlabProjection.apply_projection() 0 8 3
A SlabProjection.__init__() 0 22 1
A SlabProjection.apply_transform() 0 4 2
A SlabProjection.get_projection_function() 0 19 2
A SlabProjection.get_num_slabs() 0 11 3
B SlabProjection.projection() 0 30 6
1
from numbers import Real
2
from typing import Union, Optional, Callable
3
4
import torch
5
from torchio.data.image import ScalarImage
6
from ....data.subject import Subject
7
from ...intensity_transform import IntensityTransform
8
9
10
class SlabProjection(IntensityTransform):
    """Project intensities along a given axis, possibly with sliding slabs.

    Args:
        axis: Index for the spatial dimension to project across.
            See :class:`~.torchio.RandomFlip` for information on the accepted
            types.
        slab_thickness: Thickness of slab projections. In other words, the
            number of voxels in the ``axis`` dimension to project across.
            If ``None``, the projection will be done across the entire span of
            the ``axis`` dimension (i.e. ``axis`` dimension will be reduced to
            1). If it is larger than the span of the ``axis`` dimension of an
            image, the whole span is used for that image. Must be a positive
            integer.
        stride: Number of voxels to stride along the ``axis`` dimension between
            slab projections. Must be a positive integer. Default is 1.
        projection_type: Type of intensity projection. Possible inputs are
            ``'max'`` (the default), ``'min'``, ``'mean'``, ``'median'``, or
            ``'percentile'``. If ``'percentile'`` is used, the ``percentile``
            argument must also be supplied. For ``'mean'`` and
            ``'percentile'``, the data are cast to ``torch.float`` before
            projecting. For ``'median'``, the lower of the two middle values
            is used when a slab has an even number of voxels (see
            :func:`torch.median`).
        percentile: Percentile to use for intensity projections, in the range
            :math:`[0, 100]`. This argument is required if ``projection_type``
            is ``'percentile'`` and is silently ignored otherwise.
        full_slabs_only: Boolean. Should projections be done only for slabs
            that are ``slab_thickness`` thick? Default is ``True``.
            If ``False``, some slabs may not be ``slab_thickness`` thick
            depending on the size of the image, slab thickness, and stride.

    Example:
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> axial_mip = tio.SlabProjection("S", slab_thickness=20)
        >>> ct_t = axial_mip(ct)
        >>> ct_t.plot()

    .. plot::

        import torchio as tio
        sub = tio.datasets.Slicer('CTChest')
        ct = sub.CT_chest
        axial_mip = tio.SlabProjection("S", slab_thickness=20)
        ct_mip = axial_mip(ct)
        sub.add_image(ct_mip, 'CT_MIP')
        sub = tio.Clamp(-1000, 1000)(sub)
        sub.plot()

    """
    def __init__(
            self,
            axis: Union[int, str],
            slab_thickness: Optional[int] = None,
            stride: int = 1,
            projection_type: str = 'max',
            percentile: Optional[float] = None,
            full_slabs_only: bool = True,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.args_names = (
            'axis', 'slab_thickness', 'stride',
            'projection_type', 'percentile', 'full_slabs_only'
            )
        # Fail early: a non-positive stride would never advance the slab
        # window, and a non-positive thickness would produce empty slabs.
        if slab_thickness is not None and slab_thickness < 1:
            message = (
                '`slab_thickness` must be None or a positive integer,'
                f' but "{slab_thickness}" was passed'
            )
            raise ValueError(message)
        if stride < 1:
            message = (
                '`stride` must be a positive integer,'
                f' but "{stride}" was passed'
            )
            raise ValueError(message)
        # The attributes named in ``args_names`` keep the exact constructor
        # values (they are never overwritten while applying the transform),
        # so the transform can be re-created from them.
        self.axis = axis
        self.slab_thickness = slab_thickness
        self.stride = stride
        self.projection_fun = self.get_projection_function(projection_type)
        self.projection_type = projection_type
        self.percentile = percentile
        # Fraction in [0, 1] passed as ``q`` to torch.quantile (or the raw
        # ``percentile`` value when projection_type is not 'percentile').
        self._quantile = self.validate_percentile(percentile, projection_type)
        self.full_slabs_only = full_slabs_only

    @staticmethod
    def validate_percentile(percentile, projection_type):
        """Check ``percentile`` and convert it to a quantile in [0, 1].

        If ``projection_type`` is not ``'percentile'``, ``percentile`` is
        returned unchanged without any check.

        Raises:
            TypeError: If ``percentile`` is not a real number (booleans are
                rejected as well).
            ValueError: If ``percentile`` is outside [0, 100] (or NaN).
        """
        if projection_type != 'percentile':
            return percentile
        message = (
            "For projection_type='percentile', `percentile` must be a scalar"
            f' value in the range [0, 100], but "{percentile}" was passed'
        )
        if not isinstance(percentile, Real) or isinstance(percentile, bool):
            raise TypeError(message)
        # NaN fails this chained comparison and is rejected too
        if 0 <= percentile <= 100:
            return percentile / 100
        raise ValueError(message)

    @staticmethod
    def get_projection_function(projection_type: str) -> Callable:
        """Return the torch reduction function for ``projection_type``.

        Raises:
            ValueError: If ``projection_type`` is not a supported key.
        """
        arg_to_function = {
            'max': 'amax',
            'min': 'amin',
            'mean': 'mean',
            'median': 'median',
            'percentile': 'quantile',
        }
        try:
            function_name = arg_to_function[projection_type]
        except KeyError:
            message = (
                f'The projection type must be in {list(arg_to_function)},'
                f' but "{projection_type}" was passed'
            )
            raise ValueError(message) from None
        return getattr(torch, function_name)

    def _get_slab_thickness(self) -> int:
        """Return the slab thickness to use for the current ``axis_span``.

        ``None`` means the whole span; larger values are clamped to it.
        """
        if self.slab_thickness is None:
            return self.axis_span
        return min(self.slab_thickness, self.axis_span)

    def get_num_slabs(self) -> int:
        """Return how many slabs fit along the current ``axis_span``."""
        thickness = self._get_slab_thickness()
        if self.full_slabs_only:
            # Count starts 0, stride, 2 * stride, ... that satisfy
            # start + thickness <= axis_span.
            return (self.axis_span - thickness) // self.stride + 1
        # Every start index below axis_span yields a (possibly partial) slab:
        # ceil(axis_span / stride) computed with integer arithmetic.
        return -(-self.axis_span // self.stride)

    def apply_transform(self, subject: Subject) -> Subject:
        for image in self.get_images(subject):
            self.apply_projection(image)
        return subject

    def apply_projection(self, image: ScalarImage) -> None:
        """Replace the data of ``image`` with its slab projection."""
        # Per-image scratch state read by get_num_slabs() and projection().
        self.axis_index = image.axis_name_to_index(self.axis)
        self.axis_span = image.shape[self.axis_index]
        image.set_data(self.projection(image.data))

    def projection(self, tensor: torch.Tensor) -> torch.Tensor:
        """Project ``tensor`` slab by slab along ``self.axis_index``.

        Each slab is reduced to size 1 along the axis, and the results are
        concatenated along the same axis.
        """
        if self.projection_type in ('mean', 'percentile'):
            tensor = tensor.to(torch.float)
        thickness = self._get_slab_thickness()
        slabs = []
        for slab_index in range(self.get_num_slabs()):
            start = slab_index * self.stride
            # Trailing slabs are truncated at the end of the axis
            length = min(thickness, self.axis_span - start)
            # narrow() returns a view, so no index tensor (and no copy) needed
            slab = tensor.narrow(self.axis_index, start, length)
            slabs.append(self._project_slab(slab))
        return torch.cat(slabs, dim=self.axis_index)

    def _project_slab(self, slab: torch.Tensor) -> torch.Tensor:
        """Reduce one slab along the projection axis, keeping the dimension."""
        if self.projection_type == 'median':
            # torch.median with ``dim`` returns (values, indices)
            projected, _ = self.projection_fun(
                slab, dim=self.axis_index, keepdim=True)
        elif self.projection_type == 'percentile':
            projected = self.projection_fun(
                slab, q=self._quantile, dim=self.axis_index, keepdim=True)
        else:
            projected = self.projection_fun(
                slab, dim=self.axis_index, keepdim=True)
        return projected
169