Passed
Pull Request — master (#625)
by Fernando
01:19
created

SlabProjection.__init__()   A

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 20
dl 0
loc 22
rs 9.4
c 0
b 0
f 0
cc 1
nop 8

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists:

1
from numbers import Real
2
from typing import Union, Optional, Callable
3
4
import torch
5
from torchio.data.image import ScalarImage
6
from ....data.subject import Subject
7
from ...intensity_transform import IntensityTransform
8
9
10
class SlabProjection(IntensityTransform):
11
    """Project intensities along a given axis, possibly with sliding slabs.
12
13
    Args:
14
        axis: Index for the spatial dimension to project across.
15
            See :class:`~.torchio.RandomFlip` for information on the accepted
16
            types.
17
        slab_thickness: Thickness of slab projections. In other words, the
18
            number of voxels in the ``axis`` dimension to project across.
19
            If ``None``, the projection will be done across the entire span of
20
            the ``axis`` dimension (i.e. ``axis`` dimension will be reduced to
21
            1).
22
        stride: Number of voxels to stride along the ``axis`` dimension between
23
            slab projections. Default is 1.
24
        projection_type: Type of intensity projection. Possible inputs are
25
            ``'max'`` (the default), ``'min'``, ``'mean'``, ``'median'``, or
26
            ``'percentile'``. If ``'percentile'`` is used, the ``percentile``
27
            argument must also be supplied.
28
        percentile: Percetile to use for intensity projections. This argument
29
            is required if ``projection_type`` is ``'percentile'`` and is
30
            silently ignored otherwise.
31
        full_slabs_only: Boolean. Should projections be done only for slabs
32
            that are ``slab_thickness`` thick? Default is ``True``.
33
            If ``False``, some slabs may not be ``slab_thickness`` thick
34
            depending on the size of the image, slab thickness, and stride.
35
36
    Example:
37
        >>> import torchio as tio
38
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
39
        >>> axial_mip = tio.SlabProjection("S", slab_thickness=20)
40
        >>> ct_t = axial_mip(ct)
41
        >>> ct_t.plot()
42
43
    .. plot::
44
45
        import torchio as tio
46
        sub = tio.datasets.Slicer('CTChest')
47
        ct = sub.CT_chest
48
        axial_mip = tio.SlabProjection("S", slab_thickness=20)
49
        ct_mip = axial_mip(ct)
50
        sub.add_image(ct_mip, 'CT_MIP')
51
        sub = tio.Clamp(-1000, 1000)(sub)
52
        sub.plot()
53
54
    """
55
    def __init__(
56
            self,
57
            axis: Union[int, str],
58
            slab_thickness: Optional[int] = None,
59
            stride: int = 1,
60
            projection_type: str = 'max',
61
            percentile: Optional[float] = None,
62
            full_slabs_only: bool = True,
63
            **kwargs
64
            ):
65
        super().__init__(**kwargs)
66
        self.args_names = (
67
            'axis', 'slab_thickness', 'stride',
68
            'projection_type', 'percentile', 'full_slabs_only'
69
            )
70
        self.axis = self.parse_axes(axis)[0]
71
        self.slab_thickness = slab_thickness
72
        self.stride = stride
73
        self.projection_fun = self.get_projection_function(projection_type)
74
        self.projection_type = projection_type
75
        self.percentile = self.validate_percentile(percentile, projection_type)
76
        self.full_slabs_only = full_slabs_only
77
78
    @staticmethod
79
    def validate_percentile(percentile, projection_type):
80
        if not projection_type == 'percentile':
81
            return percentile
82
        message = (
83
            "For projection_type='percentile', `percentile` must be a scalar"
84
            f' value in the range [0, 1], but "{percentile}" was passed'
85
        )
86
        if not isinstance(percentile, Real):
87
            raise TypeError(message)
88
        elif 0 <= percentile <= 100:
89
            return percentile / 100
90
        else:
91
            raise ValueError(message)
92
93
    @staticmethod
94
    def get_projection_function(projection_type: str) -> Callable:
95
        arg_to_function = {
96
            'max': 'amax',
97
            'min': 'amin',
98
            'mean': 'mean',
99
            'median': 'median',
100
            'percentile': 'quantile',
101
        }
102
        try:
103
            function_name = arg_to_function[projection_type]
104
        except KeyError:
105
            message = (
106
                f'The projection type must be in {arg_to_function.keys()}, '
107
                f' but {projection_type} was passed'
108
            )
109
            raise ValueError(message)
110
        projection_function = getattr(torch, function_name)
111
        return projection_function
112
113
    def get_num_slabs(self, axis_span: int) -> int:
114
        if self.full_slabs_only:
115
            start_index = 0
116
            num_slabs = 0
117
            while start_index + self.slab_thickness <= axis_span:
118
                num_slabs += 1
119
                start_index += self.stride
120
        else:
121
            num_slabs = torch.ceil(torch.tensor(axis_span) / self.stride)
122
            num_slabs = int(num_slabs.item())
123
        return num_slabs
124
125
    def apply_transform(self, subject: Subject) -> Subject:
126
        axis_index = self.ensure_axes_indices(subject, [self.axis])[0]
127
        for image in self.get_images(subject):
128
            self.apply_projection(image, axis_index)
129
        return subject
130
131
    def apply_projection(self, image: ScalarImage, axis_index: int) -> None:
132
        axis_span = image.shape[axis_index]
133
        if self.slab_thickness is None:
134
            self.slab_thickness = axis_span
135
        elif self.slab_thickness > axis_span:
136
            self.slab_thickness = axis_span
137
        image.set_data(self.projection(image.data, axis_index, axis_span))
138
139
    def projection(
140
            self,
141
            tensor: torch.Tensor,
142
            axis_index: int,
143
            axis_span: int,
144
            ) -> torch.Tensor:
145
        if self.projection_type in ['mean', 'percentile']:
146
            tensor = tensor.to(torch.float)
147
148
        num_slabs = self.get_num_slabs(axis_span)
149
150
        slabs = []
151
        start_index = 0
152
        end_index = start_index + self.slab_thickness
153
154
        for _ in range(num_slabs):
155
            slab_indices = torch.arange(start_index, end_index)
156
            slab = tensor.index_select(axis_index, slab_indices)
157
            if self.projection_type == 'median':
158
                projected, _ = self.projection_fun(
159
                    slab, dim=axis_index, keepdim=True)
160
            elif self.projection_type == 'percentile':
161
                projected = self.projection_fun(
162
                    slab, q=self.percentile, dim=axis_index,
163
                    keepdim=True)
164
            else:
165
                projected = self.projection_fun(
166
                    slab, dim=axis_index, keepdim=True)
167
            slabs.append(projected)
168
            start_index += self.stride
169
            end_index = start_index + self.slab_thickness
170
            if end_index > axis_span:
171
                end_index = axis_span
172
173
        return torch.cat(slabs, dim=axis_index)
174