Passed
Pull Request — master (#625)
by Fernando
03:33
created

SlabProjection.__init__()   A

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 20
dl 0
loc 22
rs 9.4
c 0
b 0
f 0
cc 1
nop 8

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also often become inconsistent when you need more, or different, data.

There are several approaches to avoid long parameter lists:

1
from numbers import Real
2
from typing import Union, Optional, Callable
3
4
import torch
5
from torchio.data.image import ScalarImage
6
from ....data.subject import Subject
7
from ...intensity_transform import IntensityTransform
8
9
10
class SlabProjection(IntensityTransform):
    """Project intensities along a given axis, possibly with sliding slabs.

    Args:
        axis: Index for the spatial dimension to project across.
            See :class:`~.torchio.RandomFlip` for information on the accepted
            types.
        slab_thickness: Thickness of slab projections. In other words, the
            number of voxels in the ``axis`` dimension to project across.
            If ``None``, the projection will be done across the entire span of
            the ``axis`` dimension (i.e. ``axis`` dimension will be reduced to
            1). Values larger than the span of the ``axis`` dimension are
            clamped to that span, independently for each image.
        stride: Number of voxels to stride along the ``axis`` dimension between
            slab projections. Default is 1.
        projection_type: Type of intensity projection. Possible inputs are
            ``'max'`` (the default), ``'min'``, ``'mean'``, ``'median'``, or
            ``'percentile'``. If ``'percentile'`` is used, the ``percentile``
            argument must also be supplied.
        percentile: Percentile (in the range [0, 100]) to use for intensity
            projections. This argument is required if ``projection_type`` is
            ``'percentile'`` and is silently ignored otherwise.
        full_slabs_only: Boolean. Should projections be done only for slabs
            that are ``slab_thickness`` thick? Default is ``True``.
            If ``False``, some slabs may not be ``slab_thickness`` thick
            depending on the size of the image, slab thickness, and stride.

    Raises:
        ValueError: If ``stride`` or ``slab_thickness`` is smaller than 1, if
            ``projection_type`` is unknown, or if ``percentile`` is outside
            [0, 100] when ``projection_type`` is ``'percentile'``.
        TypeError: If ``projection_type`` is ``'percentile'`` and
            ``percentile`` is not a real number.

    Example:
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> axial_mip = tio.SlabProjection("S", slab_thickness=20)
        >>> ct_t = axial_mip(ct)
        >>> ct_t.plot()

    .. plot::

        import torchio as tio
        sub = tio.datasets.Slicer('CTChest')
        ct = sub.CT_chest
        axial_mip = tio.SlabProjection("S", slab_thickness=20)
        ct_mip = axial_mip(ct)
        sub.add_image(ct_mip, 'CT_MIP')
        sub = tio.Clamp(-1000, 1000)(sub)
        sub.plot()

    """
    def __init__(
            self,
            axis: Union[int, str],
            slab_thickness: Optional[int] = None,
            stride: int = 1,
            projection_type: str = 'max',
            percentile: Optional[float] = None,
            full_slabs_only: bool = True,
            **kwargs
            ):
        super().__init__(**kwargs)
        # A non-positive stride would never advance the slab window, and a
        # non-positive thickness would yield empty slabs; reject both early.
        if stride < 1:
            raise ValueError(
                f'stride must be a positive integer, but "{stride}" was passed'
            )
        if slab_thickness is not None and slab_thickness < 1:
            raise ValueError(
                'slab_thickness must be a positive integer or None,'
                f' but "{slab_thickness}" was passed'
            )
        self.args_names = (
            'axis', 'slab_thickness', 'stride',
            'projection_type', 'percentile', 'full_slabs_only'
            )
        self.axis = axis
        self.slab_thickness = slab_thickness
        self.stride = stride
        self.projection_fun = self.get_projection_function(projection_type)
        self.projection_type = projection_type
        # NOTE(review): for 'percentile' the stored value is the fraction
        # (percentile / 100), while 'percentile' is also listed in
        # ``args_names``. If the base class rebuilds the transform from these
        # attributes, the value may be scaled twice -- confirm.
        self.percentile = self.validate_percentile(percentile, projection_type)
        self.full_slabs_only = full_slabs_only

    @staticmethod
    def validate_percentile(
            percentile: Optional[float],
            projection_type: str,
            ) -> Optional[float]:
        """Check ``percentile`` and convert it to a fraction in [0, 1].

        Args:
            percentile: Value in the range [0, 100], or anything if
                ``projection_type`` is not ``'percentile'``.
            projection_type: Name of the projection being configured.

        Returns:
            ``percentile / 100`` when ``projection_type`` is ``'percentile'``;
            otherwise ``percentile`` unchanged (it is ignored downstream).

        Raises:
            TypeError: If a percentile projection is requested and
                ``percentile`` is not a real number.
            ValueError: If a percentile projection is requested and
                ``percentile`` is outside [0, 100].
        """
        if projection_type != 'percentile':
            return percentile
        message = (
            "For projection_type='percentile', `percentile` must be a scalar"
            f' value in the range [0, 100], but "{percentile}" was passed'
        )
        if not isinstance(percentile, Real):
            raise TypeError(message)
        if not 0 <= percentile <= 100:
            raise ValueError(message)
        return percentile / 100

    @staticmethod
    def get_projection_function(projection_type: str) -> Callable:
        """Return the ``torch`` reduction matching ``projection_type``.

        Raises:
            ValueError: If ``projection_type`` is not a supported name.
        """
        arg_to_function = {
            'max': 'amax',
            'min': 'amin',
            'mean': 'mean',
            'median': 'median',
            'percentile': 'quantile',
        }
        try:
            function_name = arg_to_function[projection_type]
        except KeyError:
            message = (
                f'The projection type must be in {arg_to_function.keys()}, '
                f' but {projection_type} was passed'
            )
            raise ValueError(message)
        projection_function = getattr(torch, function_name)
        return projection_function

    def _effective_thickness(self, axis_span: int) -> int:
        """Return the slab thickness to use for an axis of ``axis_span``."""
        # Computed per image instead of being written back to
        # ``self.slab_thickness`` so the transform stays stateless across
        # images of different sizes.
        if self.slab_thickness is None:
            return axis_span
        return min(self.slab_thickness, axis_span)

    def get_num_slabs(self) -> int:
        """Return how many slabs fit along the current axis.

        Reads ``self.axis_span``, which is set by :meth:`apply_projection`.
        """
        if self.full_slabs_only:
            thickness = self._effective_thickness(self.axis_span)
            # Number of window starts 0, stride, 2*stride, ... whose full
            # window still ends inside the axis.
            return (self.axis_span - thickness) // self.stride + 1
        # Ceiling division: every start index below the span yields a slab,
        # the last ones possibly thinner than requested.
        return -(-self.axis_span // self.stride)

    def apply_transform(self, subject: Subject) -> Subject:
        """Project every image selected by ``get_images`` in place."""
        for image in self.get_images(subject):
            self.apply_projection(image)
        return subject

    def apply_projection(self, image: ScalarImage) -> None:
        """Replace the data of ``image`` with its slab projection."""
        # Per-image geometry; overwritten on every call.
        self.axis_index = image.axis_name_to_index(self.axis)
        self.axis_span = image.shape[self.axis_index]
        image.set_data(self.projection(image.data))

    def projection(self, tensor: torch.Tensor) -> torch.Tensor:
        """Project ``tensor`` slab by slab along ``self.axis_index``.

        Relies on ``self.axis_index`` and ``self.axis_span`` having been set
        by :meth:`apply_projection`.

        Returns:
            Tensor whose size along the projected axis equals the number of
            slabs; each slab is reduced to a single slice.
        """
        if self.projection_type in ['mean', 'percentile']:
            # These reductions are applied to floating-point data.
            tensor = tensor.to(torch.float)

        thickness = self._effective_thickness(self.axis_span)
        num_slabs = self.get_num_slabs()

        reduce_kwargs = {'dim': self.axis_index, 'keepdim': True}
        if self.projection_type == 'percentile':
            reduce_kwargs['q'] = self.percentile

        slabs = []
        for start_index in range(0, num_slabs * self.stride, self.stride):
            # The last slabs may be truncated when full_slabs_only is False.
            end_index = min(start_index + thickness, self.axis_span)
            slab = tensor.narrow(
                self.axis_index, start_index, end_index - start_index)
            projected = self.projection_fun(slab, **reduce_kwargs)
            if self.projection_type == 'median':
                # torch.median along a dim returns (values, indices).
                projected = projected.values
            slabs.append(projected)

        return torch.cat(slabs, dim=self.axis_index)
169