Passed
Pull Request — master (#625)
by Fernando
01:52
created

SlabProjection.__init__()   A

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 20
dl 0
loc 22
rs 9.4
c 0
b 0
f 0
cc 1
nop 8

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
from numbers import Real
2
from typing import Union, Optional, Callable
3
4
import torch
5
from torchio.data.image import ScalarImage
6
from ....data.subject import Subject
7
from ...intensity_transform import IntensityTransform
8
9
10
class SlabProjection(IntensityTransform):
    """Project intensities along a given axis, possibly with sliding slabs.

    Args:
        axis: Index for the spatial dimension to project across.
            See :class:`~.torchio.RandomFlip` for information on the accepted
            options.
        slab_thickness: Thickness of slab projections (number of voxels along
            the ``axis`` dimension to project across). Must be a positive
            integer. If ``None``, the projection will be done across the
            entire span of the ``axis`` dimension (i.e. ``axis`` dimension
            will be reduced to 1). If it is larger than the span of the
            ``axis`` dimension of an image, the whole span is used for that
            image only; the configured value is not modified.
        stride: Number of voxels to stride along the ``axis`` dimension between
            slab projections. Must be a positive integer. Default is 1.
        projection_type: Type of intensity projection. Possible inputs are
            ``'max'``, ``'min'``, ``'mean'``, ``'median'``, or
            ``'percentile'``. If ``'percentile'`` is used, the ``percentile``
            argument must also be supplied.
        percentile: Percentile in the range :math:`[0, 100]` to use for
            intensity projections. This argument is required if
            ``projection_type`` is ``'percentile'`` and is silently ignored
            otherwise.
        full_slabs_only: If ``True``, projections will be done only for slabs
            that are ``slab_thickness`` thick. If ``False``, some slabs may not
            be ``slab_thickness`` thick depending on the size of the image,
            slab thickness, and stride.

    Example:
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> axial_mip = tio.SlabProjection("S", slab_thickness=20)
        >>> ct_t = axial_mip(ct)
        >>> ct_t.plot()

    .. plot::

        import torchio as tio
        sub = tio.datasets.Slicer('CTChest')
        ct = sub.CT_chest
        axial_mip = tio.SlabProjection("S", slab_thickness=20)
        ct_mip = axial_mip(ct)
        sub.add_image(ct_mip, 'CT_MIP')
        sub = tio.Clamp(-1000, 1000)(sub)
        sub.plot()

    """
    def __init__(
            self,
            axis: Union[int, str],
            slab_thickness: Optional[int] = None,
            stride: int = 1,
            projection_type: str = 'max',
            percentile: Optional[float] = None,
            full_slabs_only: bool = True,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.args_names = (
            'axis', 'slab_thickness', 'stride',
            'projection_type', 'percentile', 'full_slabs_only'
            )
        self.axis = self.parse_axes(axis)[0]
        if slab_thickness is not None:
            self._validate_positive_integer(slab_thickness, 'slab_thickness')
        # This is the user's configuration. It is never overwritten with a
        # per-image value, so the transform can be reused across subjects.
        self.slab_thickness = slab_thickness
        # A non-positive stride would never advance along the axis.
        self._validate_positive_integer(stride, 'stride')
        self.stride = stride
        self.projection_fun = self.get_projection_function(projection_type)
        self.projection_type = projection_type
        # Stored as a fraction in [0, 1] when projection_type is 'percentile'
        self.percentile = self.validate_percentile(percentile, projection_type)
        self.full_slabs_only = full_slabs_only

    @staticmethod
    def _validate_positive_integer(value, name: str) -> None:
        """Raise if ``value`` is not an integer greater than or equal to 1.

        Raises:
            TypeError: If ``value`` is not an integral number.
            ValueError: If ``value`` is smaller than 1.
        """
        from numbers import Integral
        if not isinstance(value, Integral):
            raise TypeError(
                f'`{name}` must be an integer, but "{value}" was passed'
            )
        if value < 1:
            raise ValueError(
                f'`{name}` must be at least 1, but "{value}" was passed'
            )

    @staticmethod
    def validate_percentile(
            percentile: Optional[float],
            projection_type: str,
            ) -> Optional[float]:
        """Check ``percentile`` and rescale it from [0, 100] to [0, 1].

        If ``projection_type`` is not ``'percentile'``, ``percentile`` is
        returned unchanged because it is ignored in that case.

        Raises:
            TypeError: If ``percentile`` is not a real number.
            ValueError: If ``percentile`` is outside [0, 100] (or is NaN).
        """
        if projection_type != 'percentile':
            return percentile
        message = (
            "For projection_type='percentile', `percentile` must be a scalar"
            f' value in the range [0, 100], but "{percentile}" was passed'
        )
        if not isinstance(percentile, Real):
            raise TypeError(message)
        if not 0 <= percentile <= 100:
            raise ValueError(message)
        # torch.quantile expects q in [0, 1]
        return percentile / 100

    @staticmethod
    def get_projection_function(projection_type: str) -> Callable:
        """Return the ``torch`` reduction implementing ``projection_type``.

        Raises:
            ValueError: If ``projection_type`` is not a supported type.
        """
        arg_to_function = {
            'max': 'amax',
            'min': 'amin',
            'mean': 'mean',
            'median': 'median',
            'percentile': 'quantile',
        }
        try:
            function_name = arg_to_function[projection_type]
        except KeyError:
            message = (
                f'The projection type must be in {list(arg_to_function)},'
                f' but "{projection_type}" was passed'
            )
            raise ValueError(message) from None
        return getattr(torch, function_name)

    def _get_slab_thickness(self, axis_span: int) -> int:
        """Return the slab thickness to use for an axis of length ``axis_span``.

        ``None`` means "the whole axis". Values larger than the axis are
        clamped to it.
        """
        if self.slab_thickness is None:
            return axis_span
        return min(self.slab_thickness, axis_span)

    def get_num_slabs(
            self,
            axis_span: int,
            slab_thickness: Optional[int] = None,
            ) -> int:
        """Return how many slabs fit along an axis of length ``axis_span``.

        Args:
            axis_span: Number of voxels along the projection axis.
            slab_thickness: Effective slab thickness. If ``None``, it is
                derived from the configured thickness and ``axis_span``.
        """
        if slab_thickness is None:
            slab_thickness = self._get_slab_thickness(axis_span)
        if self.full_slabs_only:
            # Count the start indices s = k * stride with s + thickness <= span
            if slab_thickness > axis_span:
                return 0
            return (axis_span - slab_thickness) // self.stride + 1
        # Every start index below axis_span begins a (possibly truncated)
        # slab. This is ceil(axis_span / stride) computed with integers.
        return (axis_span + self.stride - 1) // self.stride

    def apply_transform(self, subject: Subject) -> Subject:
        spatial_axis_index = self.ensure_axes_indices(subject, [self.axis])[0]
        tensor_axis_index = spatial_axis_index + 1  # channels is 0
        for image in self.get_images(subject):
            self.apply_projection(image, tensor_axis_index)
        return subject

    def apply_projection(self, image: ScalarImage, axis_index: int) -> None:
        """Replace the data of ``image`` with its slab projection."""
        axis_span = image.shape[axis_index]
        # Resolve the thickness per image instead of mutating
        # self.slab_thickness. Otherwise the first image's span would leak
        # into every later image and subject.
        slab_thickness = self._get_slab_thickness(axis_span)
        projected = self.projection(
            image.data, axis_index, axis_span, slab_thickness)
        image.set_data(projected)

    def _project_slab(self, slab: torch.Tensor, axis_index: int) -> torch.Tensor:
        """Reduce one slab along ``axis_index``, keeping that dimension."""
        if self.projection_type == 'median':
            # torch.median with ``dim`` returns (values, indices)
            projected, _ = self.projection_fun(
                slab, dim=axis_index, keepdim=True)
            return projected
        if self.projection_type == 'percentile':
            return self.projection_fun(
                slab, q=self.percentile, dim=axis_index, keepdim=True)
        return self.projection_fun(slab, dim=axis_index, keepdim=True)

    def projection(
            self,
            tensor: torch.Tensor,
            axis_index: int,
            axis_span: int,
            slab_thickness: Optional[int] = None,
            ) -> torch.Tensor:
        """Project ``tensor`` slab by slab along ``axis_index``.

        Args:
            tensor: Image data.
            axis_index: Tensor dimension to project across.
            axis_span: Size of ``tensor`` along ``axis_index``.
            slab_thickness: Effective slab thickness. If ``None``, it is
                derived from the configured thickness and ``axis_span``.

        Returns:
            The per-slab projections concatenated along ``axis_index``.
        """
        if slab_thickness is None:
            slab_thickness = self._get_slab_thickness(axis_span)
        if self.projection_type in ('mean', 'percentile'):
            # torch.mean / torch.quantile require floating-point input
            tensor = tensor.to(torch.float)

        num_slabs = self.get_num_slabs(axis_span, slab_thickness)
        slabs = []
        for slab_index in range(num_slabs):
            start = slab_index * self.stride
            # The trailing slab may be truncated when full_slabs_only=False
            length = min(slab_thickness, axis_span - start)
            # narrow returns a view, so no index tensor or copy is needed
            slab = tensor.narrow(axis_index, start, length)
            slabs.append(self._project_slab(slab, axis_index))

        return torch.cat(slabs, dim=axis_index)
175