Passed
Pull Request — master (#625)
by
unknown
01:48
created

Projection.__init__()   A

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 20
dl 0
loc 22
rs 9.4
c 0
b 0
f 0
cc 1
nop 8

How to fix: Many Parameters

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

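There are several approaches to avoid long parameter lists, for example introducing a parameter object that groups related arguments (such as the slab settings ``slab_thickness``, ``stride`` and ``full_slabs_only`` below) into a single value.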

1
import torch
2
from torchio.data.image import ScalarImage
3
from ....data.subject import Subject
4
from ...intensity_transform import IntensityTransform
5
from typing import Optional
6
from math import ceil
7
8
9
class Projection(IntensityTransform):
    """Project intensities along a given axis, possibly with sliding slabs.

    Args:
        axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
            ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
            versions and first letters are also valid, as only the first
            letter will be used.
        slab_thickness: Thickness of slab projections. In other words, the
            number of voxels in the ``axis`` dimension to project across.
            If ``None``, the projection will be done across the entire span of
            the ``axis`` dimension (i.e. ``axis`` dimension will be reduced to
            1). If it is larger than the span of the ``axis`` dimension of an
            image, the whole span of that image is used. Must be a positive
            integer when given.
        stride: Number of voxels to stride along the ``axis`` dimension between
            slab projections. Must be a positive integer.
        projection_type: Type of intensity projection. Possible inputs are
            ``'max'`` (the default), ``'min'``, ``'mean'``, ``'median'``, or
            ``'quantile'``. If ``'quantile'`` is used, ``q`` must also be
            supplied.
        q: Quantile to use for intensity projections. This argument is required
            if ``projection_type`` is ``'quantile'`` and is silently ignored
            otherwise.
        full_slabs_only: Boolean. Should projections be done only for slabs
            that are ``slab_thickness`` thick? Default is ``True``.
            If ``False``, some slabs may not be ``slab_thickness`` thick
            depending on the size of the image, slab thickness, and stride.

    Raises:
        ValueError: If ``stride`` is not a positive number, if
            ``slab_thickness`` is given and not positive, if
            ``projection_type`` is not one of the supported values, or if
            ``projection_type`` is ``'quantile'`` and ``q`` is missing or
            outside ``[0, 1]``.

    Example:
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> axial_mips = tio.Projection("S", slab_thickness=20, stride=20)
        >>> ct_t = axial_mips(ct)
        >>> ct_t.plot()

    .. plot::

        import torchio as tio
        sub = tio.datasets.Slicer('CTChest')
        ct = sub.CT_chest
        axial_mips = tio.Projection("S", slab_thickness=20, stride=20)
        ct_mips = axial_mips(ct)
        sub.add_image(ct_mips, 'MIP')
        sub.plot()

    """
    def __init__(
            self,
            axis: str,
            slab_thickness: Optional[int] = None,
            stride: Optional[int] = 1,
            projection_type: Optional[str] = 'max',
            q: Optional[float] = None,
            full_slabs_only: Optional[bool] = True,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.args_names = (
            'axis', 'slab_thickness', 'stride',
            'projection_type', 'q', 'full_slabs_only'
            )
        self.axis = axis
        # Stored exactly as given and never modified afterwards, so the same
        # transform behaves consistently across images of different sizes.
        self.slab_thickness = slab_thickness
        self.stride = stride
        self.projection_type = projection_type
        self.q = q
        self.full_slabs_only = full_slabs_only
        self._validate_slab_parameters()
        self.projection_fun = self.get_projection_function()

    def _validate_slab_parameters(self) -> None:
        """Raise ``ValueError`` if ``stride`` or ``slab_thickness`` is invalid."""
        # A stride of 0 never advances between slabs: counting full slabs
        # would loop forever and counting partial slabs would divide by zero.
        if self.stride is None or self.stride < 1:
            message = f'`stride` must be a positive integer, not {self.stride}.'
            raise ValueError(message)
        # An empty slab cannot be reduced by max/min/median/etc.
        if self.slab_thickness is not None and self.slab_thickness < 1:
            message = (
                '`slab_thickness` must be a positive integer or None,'
                f' not {self.slab_thickness}.'
                )
            raise ValueError(message)

    def get_projection_function(self):
        """Return the torch reduction matching ``self.projection_type``.

        Raises:
            ValueError: If ``projection_type`` is not supported, or if it is
                ``'quantile'`` and ``q`` is invalid.
        """
        if self.projection_type == 'max':
            projection_fun = torch.amax
        elif self.projection_type == 'min':
            projection_fun = torch.amin
        elif self.projection_type == 'mean':
            projection_fun = torch.mean
        elif self.projection_type == 'median':
            projection_fun = torch.median
        elif self.projection_type == 'quantile':
            projection_fun = torch.quantile
            self.validate_quantile()
        else:
            message = (
                '`projection_type` must be one of "max", "min", "mean",'
                ' "median", or "quantile".'
                )
            raise ValueError(message)
        return projection_fun

    def validate_quantile(self):
        """Raise ``ValueError`` unless ``self.q`` is a number in ``[0, 1]``."""
        # NaN also fails the range check, since every comparison with it is
        # False.
        if self.q is None or not 0 <= self.q <= 1:
            message = (
                'For `projection_type="quantile"`, `q` must be a scalar value'
                f' in the range [0, 1], not {self.q}.'
                )
            raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        """Project every image selected by ``get_images`` in place."""
        for image in self.get_images(subject):
            self.apply_projection(image)
        return subject

    def apply_projection(self, image: ScalarImage) -> None:
        """Replace the data of ``image`` with its slab projections."""
        # Per-image geometry; recomputed on every call and read by
        # ``projection``.
        self.axis_index = image.axis_name_to_index(self.axis)
        self.axis_span = image.shape[self.axis_index]
        image.set_data(self.projection(image.data))

    def _effective_slab_thickness(self, axis_span: int) -> int:
        """Return the slab thickness to use for an axis of ``axis_span``."""
        if self.slab_thickness is None:
            return axis_span
        return min(self.slab_thickness, axis_span)

    def _count_slabs(self, axis_span: int, slab_thickness: int) -> int:
        """Return how many slabs start along an axis of length ``axis_span``."""
        if self.full_slabs_only:
            # Starts 0, stride, 2 * stride, ... whose slab ends in bounds.
            # ``slab_thickness <= axis_span`` guarantees at least one slab.
            return (axis_span - slab_thickness) // self.stride + 1
        # Every start index inside the axis; the last slabs may be truncated.
        # Integer ceil(axis_span / stride).
        return -(-axis_span // self.stride)

    def _project_slab(self, slab: torch.Tensor, axis_index: int) -> torch.Tensor:
        """Reduce ``slab`` along ``axis_index``, keeping that dimension."""
        if self.projection_type == 'median':
            # torch.median with ``dim`` returns (values, indices).
            projected, _ = self.projection_fun(
                slab, dim=axis_index, keepdim=True)
        elif self.projection_type == 'quantile':
            projected = self.projection_fun(
                slab, q=self.q, dim=axis_index, keepdim=True)
        else:
            projected = self.projection_fun(
                slab, dim=axis_index, keepdim=True)
        return projected

    def projection(self, tensor: torch.Tensor) -> torch.Tensor:
        """Project ``tensor`` slab by slab along ``self.axis_index``.

        Reads ``self.axis_index`` and ``self.axis_span``, which are set by
        ``apply_projection``. Slab ``i`` covers indices
        ``[i * stride, min(i * stride + thickness, axis_span))``, and the
        projections are concatenated along the same axis.
        """
        # torch.mean and torch.quantile require floating-point input.
        if self.projection_type in ['mean', 'quantile']:
            tensor = tensor.to(torch.float)

        axis_index = self.axis_index
        axis_span = self.axis_span
        slab_thickness = self._effective_slab_thickness(axis_span)
        num_slabs = self._count_slabs(axis_span, slab_thickness)

        slabs = []
        for slab_number in range(num_slabs):
            start = slab_number * self.stride
            length = min(slab_thickness, axis_span - start)
            # ``narrow`` returns a view, so no index tensor or copy is needed.
            slab = tensor.narrow(axis_index, start, length)
            slabs.append(self._project_slab(slab, axis_index))

        return torch.cat(slabs, dim=axis_index)
159