Passed
Pull Request — master (#625)
by
unknown
01:16
created

Projection.__init__()   A

Complexity

Conditions 1

Size

Total Lines 22
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 20
dl 0
loc 22
rs 9.4
c 0
b 0
f 0
cc 1
nop 8

How to fix: Many Parameters

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more or different data.

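There are several approaches to avoid long parameter lists, for example grouping related parameters into a single parameter object, or passing a whole object instead of several of its individual fields.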

1
import torch
2
from torchio.data.image import ScalarImage
3
from ....data.subject import Subject
4
from ...intensity_transform import IntensityTransform
5
from typing import Optional
6
from math import ceil
7
8
9
class Projection(IntensityTransform):
    """Project intensities along a given axis, possibly with sliding slabs.

    Args:
        axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
            ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
            versions and first letters are also valid, as only the first
            letter will be used.
        slab_thickness: Thickness of slab projections. In other words, the
            number of voxels in the ``axis`` dimension to project across.
            If ``None``, the projection will be done across the entire span of
            the ``axis`` dimension (i.e. ``axis`` dimension will be reduced to
            1). Values larger than that span are clipped to it separately for
            each image; the stored value itself is never modified.
        stride: Number of voxels to stride along the ``axis`` dimension between
            slab projections. Must be a positive integer.
        projection_type: Type of intensity projection. Possible inputs are
            ``'max'`` (the default), ``'min'``, ``'mean'``, ``'median'``, or
            ``'quantile'``. If ``'quantile'`` is used, ``q`` must also be
            supplied. ``'median'`` uses :func:`torch.median`, which returns
            the lower of the two middle values for an even number of voxels.
        q: Quantile to use for intensity projections. This argument is required
            if ``projection_type`` is ``'quantile'`` and is silently ignored
            otherwise.
        full_slabs_only: Boolean. Should projections be done only for slabs
            that are ``slab_thickness`` thick? Default is ``True``.
            If ``False``, some slabs may not be ``slab_thickness`` thick
            depending on the size of the image, slab thickness, and stride.

    Raises:
        ValueError: If ``slab_thickness`` is not ``None`` and smaller than 1,
            if ``stride`` is ``None`` or smaller than 1, if
            ``projection_type`` is not supported, or if ``projection_type`` is
            ``'quantile'`` and ``q`` is not strictly between 0 and 1.

    Example:
        >>> import torchio as tio
        >>> sub = tio.datasets.Colin27()
        >>> axial_mips = tio.Projection("S", slab_thickness=20)
        >>> sub_t = axial_mips(sub)
        >>> sub_t.t1.plot()
    """
    def __init__(
            self,
            axis: str,
            slab_thickness: Optional[int] = None,
            stride: int = 1,
            projection_type: str = 'max',
            q: Optional[float] = None,
            full_slabs_only: bool = True,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.args_names = (
            'axis', 'slab_thickness', 'stride',
            'projection_type', 'q', 'full_slabs_only'
            )
        self.axis = axis
        self.slab_thickness = slab_thickness
        self.stride = stride
        self.projection_type = projection_type
        self.q = q
        self.full_slabs_only = full_slabs_only
        # Fail fast: a non-positive stride would otherwise loop forever or
        # divide by zero only once the transform is applied.
        self._validate_slab_parameters()
        self.projection_fun = self.get_projection_function()

    def _validate_slab_parameters(self) -> None:
        """Raise ``ValueError`` if ``slab_thickness`` or ``stride`` is invalid."""
        if self.slab_thickness is not None and self.slab_thickness < 1:
            message = (
                '`slab_thickness` must be a positive integer or None,'
                f' not {self.slab_thickness}.'
                )
            raise ValueError(message)
        if self.stride is None or self.stride < 1:
            message = f'`stride` must be a positive integer, not {self.stride}.'
            raise ValueError(message)

    def get_projection_function(self):
        """Return the torch reduction matching ``self.projection_type``.

        Raises:
            ValueError: If ``projection_type`` is not supported, or if it is
                ``'quantile'`` and ``q`` is invalid.
        """
        if self.projection_type == 'max':
            projection_fun = torch.amax
        elif self.projection_type == 'min':
            projection_fun = torch.amin
        elif self.projection_type == 'mean':
            projection_fun = torch.mean
        elif self.projection_type == 'median':
            projection_fun = torch.median
        elif self.projection_type == 'quantile':
            projection_fun = torch.quantile
            self.validate_quantile()
        else:
            message = (
                '`projection_type` must be one of "max", "min", "mean",'
                ' "median", or "quantile".'
                )
            raise ValueError(message)
        return projection_fun

    def validate_quantile(self):
        """Raise ``ValueError`` unless ``q`` is a number with ``0 < q < 1``."""
        if self.q is None or not 0 < self.q < 1:
            message = (
                'For `projection_type="quantile"`, `q` must be a scalar value'
                f' between 0 and 1, not {self.q}.'
                )
            raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        """Project every image selected by ``get_images`` in place."""
        for image in self.get_images(subject):
            self.apply_projection(image)
        return subject

    def apply_projection(self, image: ScalarImage) -> None:
        """Replace ``image``'s data with its slab projection along ``axis``."""
        # The axis name is mapped to a tensor dimension by each image itself,
        # so it is resolved again for every image.
        self.axis_index = image.axis_name_to_index(self.axis)
        # Kept as an attribute for backward compatibility; ``projection``
        # derives the span from the tensor it receives.
        self.axis_span = image.shape[self.axis_index]
        image.set_data(self.projection(image.data))

    def _effective_slab_thickness(self, axis_span: int) -> int:
        """Return the slab thickness to use for an axis of ``axis_span`` voxels.

        Computed per call instead of overwriting ``self.slab_thickness``, so
        that ``None`` keeps meaning "whole axis" for every image and subject.
        """
        if self.slab_thickness is None:
            return axis_span
        return min(self.slab_thickness, axis_span)

    def _project_slab(self, slab: torch.Tensor) -> torch.Tensor:
        """Reduce ``slab`` along ``self.axis_index``, keeping that dimension."""
        if self.projection_type == 'median':
            # torch.median with ``dim`` returns (values, indices).
            projected, _ = self.projection_fun(
                slab, dim=self.axis_index, keepdim=True)
        elif self.projection_type == 'quantile':
            projected = self.projection_fun(
                slab, q=self.q, dim=self.axis_index, keepdim=True)
        else:
            projected = self.projection_fun(
                slab, dim=self.axis_index, keepdim=True)
        return projected

    def projection(self, tensor: torch.Tensor) -> torch.Tensor:
        """Project ``tensor`` slab by slab along ``self.axis_index``.

        ``self.axis_index`` must already be set (``apply_projection`` does
        this). The output has one entry per slab along that dimension.
        """
        axis_span = tensor.shape[self.axis_index]
        slab_thickness = self._effective_slab_thickness(axis_span)
        if self.full_slabs_only:
            # Only starts whose slab fits entirely inside the axis.
            starts = range(0, axis_span - slab_thickness + 1, self.stride)
        else:
            # Every stride position; trailing slabs may be truncated.
            starts = range(0, axis_span, self.stride)

        slabs = []
        for start in starts:
            length = min(slab_thickness, axis_span - start)
            # ``narrow`` returns a view, so no index tensor has to be created
            # (which also keeps this working for tensors on any device).
            slab = tensor.narrow(self.axis_index, start, length)
            slabs.append(self._project_slab(slab))

        return torch.cat(slabs, dim=self.axis_index)
145