Passed
Pull Request — master (#625)
by
unknown
01:17
created

Projection.__init__()   A

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 20
dl 0
loc 22
rs 9.4
c 0
b 0
f 0
cc 1
nop 8

How to fix   Many Parameters   

Many Parameters

Methods with many parameters are not only hard to understand; their parameters also often become inconsistent when you need more or different data.

There are several approaches to avoid long parameter lists:

1
import torch
2
from torchio.data.image import ScalarImage
3
from ....data.subject import Subject
4
from ...intensity_transform import IntensityTransform
5
from typing import Optional
6
from math import ceil
7
8
9
class Projection(IntensityTransform):
    """Project intensities along a given axis, possibly with sliding slabs.

    Args:
        axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
            ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
            versions and first letters are also valid, as only the first
            letter will be used.
        slab_thickness: Thickness of slab projections. In other words, the
            number of voxels in the ``axis`` dimension to project across.
            If ``None``, the projection will be done across the entire span of
            the ``axis`` dimension (i.e. ``axis`` dimension will be reduced to
            1). Values larger than the span of an image are clamped to that
            span for that image only.
        stride: Number of voxels to stride along the ``axis`` dimension between
            slab projections. Must be at least 1.
        projection_type: Type of intensity projection. Possible inputs are
            ``'max'`` (the default), ``'min'``, ``'mean'``, ``'median'``, or
            ``'quantile'``. If ``'quantile'`` is used, ``q`` must also be
            supplied.
        q: Quantile to use for intensity projections. This argument is required
            if ``projection_type`` is ``'quantile'`` and is silently ignored
            otherwise.
        full_slabs_only: Boolean. Should projections be done only for slabs
            that are ``slab_thickness`` thick? Default is ``True``.
            If ``False``, some slabs may not be ``slab_thickness`` thick
            depending on the size of the image, slab thickness, and stride.

    Raises:
        ValueError: If ``projection_type`` is not supported, if ``q`` is
            missing or outside ``[0, 1]`` for quantile projections, or if
            ``stride`` or ``slab_thickness`` is smaller than 1.

    Example:
        >>> import torchio as tio
        >>> sub = tio.datasets.Colin27()
        >>> axial_mips = tio.Projection("S", slab_thickness=20)
        >>> sub_t = axial_mips(sub)
        >>> sub_t.t1.plot()
    """
    def __init__(
            self,
            axis: str,
            slab_thickness: Optional[int] = None,
            stride: Optional[int] = 1,
            projection_type: Optional[str] = 'max',
            q: Optional[float] = None,
            full_slabs_only: Optional[bool] = True,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.args_names = (
            'axis', 'slab_thickness', 'stride',
            'projection_type', 'q', 'full_slabs_only'
            )
        self.axis = axis
        self.slab_thickness = slab_thickness
        self.stride = stride
        self.projection_type = projection_type
        self.q = q
        self.full_slabs_only = full_slabs_only
        # Fail at construction time rather than hanging or crashing with an
        # opaque error in the middle of a pipeline.
        self._validate_slab_arguments()
        self.projection_fun = self.get_projection_function()

    def _validate_slab_arguments(self) -> None:
        """Check that ``stride`` and ``slab_thickness`` are usable.

        Raises:
            ValueError: If ``stride`` is ``None`` or smaller than 1, or if
                ``slab_thickness`` is given and smaller than 1.
        """
        if self.stride is None or self.stride < 1:
            message = (
                '`stride` must be a positive integer,'
                f' not {self.stride}.'
                )
            raise ValueError(message)
        if self.slab_thickness is not None and self.slab_thickness < 1:
            message = (
                '`slab_thickness` must be a positive integer or None,'
                f' not {self.slab_thickness}.'
                )
            raise ValueError(message)

    def get_projection_function(self):
        """Return the torch reduction matching ``self.projection_type``.

        Raises:
            ValueError: If ``projection_type`` is not supported, or if it is
                ``'quantile'`` and ``q`` is invalid.
        """
        projection_functions = {
            'max': torch.amax,
            'min': torch.amin,
            'mean': torch.mean,
            'median': torch.median,
            'quantile': torch.quantile,
        }
        if self.projection_type not in projection_functions:
            message = (
                '`projection_type` must be one of "max", "min", "mean",'
                ' "median", or "quantile".'
                )
            raise ValueError(message)
        if self.projection_type == 'quantile':
            self.validate_quantile()
        return projection_functions[self.projection_type]

    def validate_quantile(self):
        """Check that ``self.q`` is a value in ``[0, 1]``.

        Raises:
            ValueError: If ``q`` is ``None`` or outside ``[0, 1]``.
        """
        if self.q is None or not 0 <= self.q <= 1:
            message = (
                'For `projection_type="quantile"`, `q` must be a scalar value'
                f' in the range [0, 1], not {self.q}.'
                )
            raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        """Project every image selected by ``get_images`` and return subject."""
        for image in self.get_images(subject):
            self.apply_projection(image)
        return subject

    def apply_projection(self, image: ScalarImage) -> None:
        """Replace the data of ``image`` with its projection, in place."""
        # These two are refreshed for every image. ``self.slab_thickness`` is
        # deliberately NOT overwritten here: the effective thickness is
        # resolved per image inside ``projection`` so that one image's span
        # never leaks into the next image or subject.
        self.axis_index = image.axis_name_to_index(self.axis)
        self.axis_span = image.shape[self.axis_index]
        image.set_data(self.projection(image.data))

    def _effective_slab_thickness(self) -> int:
        """Return the slab thickness to use for the current ``axis_span``."""
        if self.slab_thickness is None:
            return self.axis_span
        return min(self.slab_thickness, self.axis_span)

    def _project_slab(self, slab: torch.Tensor) -> torch.Tensor:
        """Reduce one slab along ``axis_index``, keeping that dimension."""
        reduce_kwargs = {'dim': self.axis_index, 'keepdim': True}
        if self.projection_type == 'quantile':
            reduce_kwargs['q'] = self.q
        projected = self.projection_fun(slab, **reduce_kwargs)
        if self.projection_type == 'median':
            # ``torch.median`` with ``dim`` returns (values, indices).
            projected = projected.values
        return projected

    def projection(self, tensor: torch.Tensor) -> torch.Tensor:
        """Project ``tensor`` slab by slab along ``self.axis_index``.

        Relies on ``self.axis_index`` and ``self.axis_span`` having been set
        by ``apply_projection``.

        Returns:
            The per-slab projections concatenated along ``self.axis_index``.
        """
        if self.projection_type in ['mean', 'quantile']:
            # The tensor is cast to float before these two reductions.
            tensor = tensor.to(torch.float)

        thickness = self._effective_slab_thickness()

        if self.full_slabs_only:
            # Number of start positions k * stride such that the slab still
            # fits entirely: k * stride + thickness <= axis_span.
            num_slabs = (self.axis_span - thickness) // self.stride + 1
        else:
            num_slabs = ceil(self.axis_span / self.stride)

        slabs = []
        for slab_number in range(num_slabs):
            start_index = slab_number * self.stride
            # Trailing slabs are truncated at the end of the axis.
            end_index = min(start_index + thickness, self.axis_span)
            # ``narrow`` returns a view on the tensor's own device, so no
            # index tensor has to be built.
            slab = tensor.narrow(
                self.axis_index, start_index, end_index - start_index)
            slabs.append(self._project_slab(slab))

        return torch.cat(slabs, dim=self.axis_index)
148