Passed
Pull Request — master (#625)
by
unknown
01:16
created

SlabProjection.get_num_slabs()   A

Complexity

Conditions 3

Size

Total Lines 11
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 10
dl 0
loc 11
rs 9.9
c 0
b 0
f 0
cc 3
nop 1
1
import torch
2
from torchio.data.image import ScalarImage
3
from ....data.subject import Subject
4
from ...intensity_transform import IntensityTransform
5
from typing import Optional
6
7
8
class SlabProjection(IntensityTransform):
    """Project intensities along a given axis, possibly with sliding slabs.

    Args:
        axis: Name of the axis to project across, as accepted by
            :meth:`torchio.Image.axis_name_to_index` (e.g. ``'S'``, as in the
            example below).
        slab_thickness: Thickness of slab projections. In other words, the
            number of voxels in the ``axis`` dimension to project across.
            Must be a positive integer. For each image, values larger than
            the size of the ``axis`` dimension are clamped to that size.
            If ``None``, the projection will be done across the entire span of
            the ``axis`` dimension (i.e. ``axis`` dimension will be reduced to
            1).
        stride: Number of voxels to stride along the ``axis`` dimension between
            slab projections. Must be a positive integer. Default is 1.
        projection_type: Type of intensity projection. Possible inputs are
            ``'max'`` (the default), ``'min'``, ``'mean'``, ``'median'``, or
            ``'percentile'``. If ``'percentile'`` is used, the ``percentile``
            argument must also be supplied. ``'median'`` uses
            :func:`torch.median`, which returns the lower of the two middle
            values when a slab contains an even number of voxels.
        percentile: Percentile, in the range ``[0, 100]``, to use for
            intensity projections. This argument is required if
            ``projection_type`` is ``'percentile'`` and is silently ignored
            otherwise.
        full_slabs_only: Boolean. Should projections be done only for slabs
            that are ``slab_thickness`` thick? Default is ``True``.
            If ``False``, a slab is started every ``stride`` voxels along the
            whole ``axis`` dimension, so slabs near the end may be thinner
            than ``slab_thickness``.

    Raises:
        ValueError: If ``slab_thickness`` or ``stride`` is not positive, if
            ``projection_type`` is not one of the supported values, or if
            ``projection_type`` is ``'percentile'`` and ``percentile`` is
            missing or outside ``[0, 100]``.

    Example:
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> axial_mip = tio.SlabProjection("S", slab_thickness=20)
        >>> ct_t = axial_mip(ct)
        >>> ct_t.plot()

    .. plot::

        import torchio as tio
        sub = tio.datasets.Slicer('CTChest')
        ct = sub.CT_chest
        axial_mip = tio.SlabProjection("S", slab_thickness=20)
        ct_mip = axial_mip(ct)
        sub.add_image(ct_mip, 'CT_MIP')
        sub = tio.Clamp(-1000, 1000)(sub)
        sub.plot()

    """
    def __init__(
            self,
            axis: str,
            slab_thickness: Optional[int] = None,
            stride: int = 1,
            projection_type: str = 'max',
            percentile: Optional[float] = None,
            full_slabs_only: bool = True,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.args_names = (
            'axis', 'slab_thickness', 'stride',
            'projection_type', 'percentile', 'full_slabs_only'
            )
        self.axis = axis
        # Stored exactly as given; the per-image clamped value is computed in
        # _get_slab_thickness so it never leaks into later images or calls.
        self.slab_thickness = self._validate_positive_int(
            slab_thickness, 'slab_thickness', allow_none=True)
        self.stride = self._validate_positive_int(stride, 'stride')
        self.projection_type = projection_type
        # Attributes named in ``args_names`` keep the caller's values, so
        # ``percentile`` stays on the 0-100 scale. The fraction expected by
        # torch.quantile is kept separately.
        self.percentile = percentile
        self._quantile = self.validate_percentile(percentile)
        self.full_slabs_only = full_slabs_only
        self.projection_fun = self.get_projection_function()

    @staticmethod
    def _validate_positive_int(value, name, allow_none=False):
        """Return ``value`` if it is >= 1 (or ``None`` when allowed)."""
        if value is None and allow_none:
            return value
        if value is None or value < 1:
            message = f'`{name}` must be a positive integer, not {value}.'
            raise ValueError(message)
        return value

    def validate_percentile(self, percentile):
        """Check ``percentile`` and convert it for use with torch.quantile.

        Returns:
            ``percentile / 100`` if ``projection_type`` is ``'percentile'``;
            otherwise ``percentile`` unchanged (it is not used).

        Raises:
            ValueError: If ``projection_type`` is ``'percentile'`` and
                ``percentile`` is ``None`` or outside ``[0, 100]``.
        """
        if self.projection_type != 'percentile':
            return percentile
        if percentile is None or not 0 <= percentile <= 100:
            message = (
                "For projection_type='percentile', `percentile` must be a scalar"
                f' value in the range [0, 100], not {percentile}.'
                )
            raise ValueError(message)
        return percentile / 100

    def get_projection_function(self):
        """Return the torch reduction that matches ``projection_type``.

        Raises:
            ValueError: If ``projection_type`` is not supported.
        """
        if self.projection_type == 'max':
            projection_fun = torch.amax
        elif self.projection_type == 'min':
            projection_fun = torch.amin
        elif self.projection_type == 'mean':
            projection_fun = torch.mean
        elif self.projection_type == 'median':
            projection_fun = torch.median
        elif self.projection_type == 'percentile':
            projection_fun = torch.quantile
        else:
            message = (
                '`projection_type` must be one of "max", "min", "mean",'
                ' "median", or "percentile".'
                )
            raise ValueError(message)
        return projection_fun

    def _get_slab_thickness(self) -> int:
        """Slab thickness for the current image, clamped to ``axis_span``."""
        if self.slab_thickness is None:
            return self.axis_span
        return min(self.slab_thickness, self.axis_span)

    def get_num_slabs(self) -> int:
        """Return the number of slabs for the image being projected.

        Relies on ``self.axis_span``, which :meth:`apply_projection` sets for
        each image before projecting it.
        """
        thickness = self._get_slab_thickness()
        if self.full_slabs_only:
            # Start indices 0, stride, 2 * stride, ... whose slab fits
            # entirely inside the axis (thickness <= axis_span by clamping).
            return (self.axis_span - thickness) // self.stride + 1
        # One slab per start index inside the axis: ceil(axis_span / stride).
        return -(-self.axis_span // self.stride)

    def apply_transform(self, subject: Subject) -> Subject:
        for image in self.get_images(subject):
            self.apply_projection(image)
        return subject

    def apply_projection(self, image: ScalarImage) -> None:
        """Replace the data of ``image`` with its slab projection."""
        # Per-image geometry, overwritten for every image.
        self.axis_index = image.axis_name_to_index(self.axis)
        self.axis_span = image.shape[self.axis_index]
        image.set_data(self.projection(image.data))

    def _project_slab(self, slab: torch.Tensor) -> torch.Tensor:
        """Reduce ``slab`` along the projection axis, keeping that dimension."""
        if self.projection_type == 'median':
            # torch.median with ``dim`` returns (values, indices).
            projected, _ = self.projection_fun(
                slab, dim=self.axis_index, keepdim=True)
        elif self.projection_type == 'percentile':
            projected = self.projection_fun(
                slab, q=self._quantile, dim=self.axis_index, keepdim=True)
        else:
            projected = self.projection_fun(
                slab, dim=self.axis_index, keepdim=True)
        return projected

    def projection(self, tensor: torch.Tensor) -> torch.Tensor:
        """Project ``tensor`` slab by slab along ``self.axis_index``.

        Requires ``self.axis_index`` and ``self.axis_span`` to describe
        ``tensor``, as set by :meth:`apply_projection`. The output has one
        entry per slab along the projection axis.
        """
        # torch.mean / torch.quantile need a floating-point input; float32
        # and float64 are kept as they are to avoid losing precision.
        needs_float = self.projection_type in ('mean', 'percentile')
        if needs_float and tensor.dtype not in (torch.float32, torch.float64):
            tensor = tensor.to(torch.float)

        thickness = self._get_slab_thickness()
        slabs = []
        for slab_number in range(self.get_num_slabs()):
            start = slab_number * self.stride
            # Trailing slabs (only when full_slabs_only is False) are cut at
            # the end of the axis.
            length = min(thickness, self.axis_span - start)
            # narrow returns a view, so no copy is made per slab.
            slab = tensor.narrow(self.axis_index, start, length)
            slabs.append(self._project_slab(slab))
        return torch.cat(slabs, dim=self.axis_index)
165