Passed
Pull Request — master (#625)
by
unknown
01:22
created

torchio.transforms.preprocessing.intensity.projection   A

Complexity

Total Complexity 23

Size/Duplication

Total Lines 159
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 23
eloc 94
dl 0
loc 159
rs 10
c 0
b 0
f 0

6 Methods

Rating   Name   Duplication   Size   Complexity  
B Projection.projection() 0 36 8
A Projection.validate_quantile() 0 11 3
A Projection.apply_transform() 0 4 2
A Projection.__init__() 0 22 1
B Projection.get_projection_function() 0 19 6
A Projection.apply_projection() 0 8 3
1
import torch
2
from torchio.data.image import ScalarImage
3
from ....data.subject import Subject
4
from ...intensity_transform import IntensityTransform
5
from typing import Optional
6
from math import ceil
7
8
9
class Projection(IntensityTransform):
    """Project intensities along a given axis, possibly with sliding slabs.

    Args:
        axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
            ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
            versions and first letters are also valid, as only the first
            letter will be used.
        slab_thickness: Thickness of slab projections. In other words, the
            number of voxels in the ``axis`` dimension to project across.
            If ``None``, the projection will be done across the entire span of
            the ``axis`` dimension (i.e. ``axis`` dimension will be reduced to
            1). Values larger than the span of the ``axis`` dimension of an
            image are clipped to that span, per image. Must be a positive
            integer if given.
        stride: Number of voxels to stride along the ``axis`` dimension between
            slab projections. Must be a positive integer.
        projection_type: Type of intensity projection. Possible inputs are
            ``'max'`` (the default), ``'min'``, ``'mean'``, ``'median'``, or
            ``'quantile'``. If ``'quantile'`` is used, ``q`` must also be
            supplied. For ``'median'`` over an even number of voxels, the
            lower of the two middle values is used (``torch.median``
            semantics).
        q: Quantile to use for intensity projections. This argument is required
            if ``projection_type`` is ``'quantile'`` and is silently ignored
            otherwise.
        full_slabs_only: Boolean. Should projections be done only for slabs
            that are ``slab_thickness`` thick? Default is ``True``.
            If ``False``, some slabs may not be ``slab_thickness`` thick
            depending on the size of the image, slab thickness, and stride.

    Raises:
        ValueError: If ``stride`` or ``slab_thickness`` is not positive, if
            ``projection_type`` is unknown, or if ``q`` is missing or outside
            ``[0, 1]`` for quantile projections.

    Example:
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> axial_mips = tio.Projection("S", slab_thickness=20, stride=20)
        >>> ct_t = axial_mips(ct)
        >>> ct_t.plot()

    .. plot::

        import torchio as tio
        sub = tio.datasets.Slicer('CTChest')
        ct = sub.CT_chest
        axial_mips = tio.Projection("S", slab_thickness=20, stride=20)
        ct_mips = axial_mips(ct)
        sub.add_image(ct_mips, 'MIP')
        sub.plot()

    """
    def __init__(
            self,
            axis: str,
            slab_thickness: Optional[int] = None,
            stride: Optional[int] = 1,
            projection_type: Optional[str] = 'max',
            q: Optional[float] = None,
            full_slabs_only: Optional[bool] = True,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.args_names = (
            'axis', 'slab_thickness', 'stride',
            'projection_type', 'q', 'full_slabs_only'
            )
        self.axis = axis
        self.slab_thickness = slab_thickness
        self.stride = stride
        self.projection_type = projection_type
        self.q = q
        self.full_slabs_only = full_slabs_only
        self._validate_slab_parameters()
        self.projection_fun = self.get_projection_function()

    def _validate_slab_parameters(self) -> None:
        """Reject stride / slab thickness values that cannot produce slabs."""
        # A non-positive stride would never advance (infinite loop when
        # counting full slabs) or divide by zero when counting partial slabs.
        if self.stride is None or self.stride < 1:
            message = (
                f'`stride` must be a positive integer, not {self.stride}.'
                )
            raise ValueError(message)
        if self.slab_thickness is not None and self.slab_thickness < 1:
            message = (
                '`slab_thickness` must be a positive integer or None,'
                f' not {self.slab_thickness}.'
                )
            raise ValueError(message)

    def get_projection_function(self):
        """Return the torch reduction matching ``self.projection_type``.

        Raises:
            ValueError: If ``projection_type`` is not supported, or if it is
                ``'quantile'`` and ``q`` is invalid.
        """
        functions = {
            'max': torch.amax,
            'min': torch.amin,
            'mean': torch.mean,
            'median': torch.median,
            'quantile': torch.quantile,
            }
        if self.projection_type not in functions:
            message = (
                '`projection_type` must be one of "max", "min", "mean",'
                ' "median", or "quantile".'
                )
            raise ValueError(message)
        if self.projection_type == 'quantile':
            self.validate_quantile()
        return functions[self.projection_type]

    def validate_quantile(self):
        """Check that ``self.q`` is a scalar in the closed range [0, 1].

        Raises:
            ValueError: If ``q`` is ``None`` or outside ``[0, 1]``.
        """
        message = (
            'For `projection_type="quantile"`, `q` must be a scalar value'
            f' in the range [0, 1], not {self.q}.'
            )
        if self.q is None or not 0 <= self.q <= 1:
            raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        """Project every selected image of ``subject`` in place."""
        for image in self.get_images(subject):
            self.apply_projection(image)
        return subject

    def apply_projection(self, image: ScalarImage) -> None:
        """Replace ``image``'s data with its slab-wise projection.

        The configured ``slab_thickness`` is resolved per image (``None`` or
        oversized values become the image's span) without modifying the
        transform's own attributes, so images of different sizes are handled
        independently.
        """
        self.axis_index = image.axis_name_to_index(self.axis)
        self.axis_span = image.shape[self.axis_index]
        image.set_data(self.projection(image.data))

    def _get_slab_thickness(self, axis_span: int) -> int:
        """Resolve the configured slab thickness for an axis of given span."""
        if self.slab_thickness is None:
            return axis_span
        return min(self.slab_thickness, axis_span)

    def _project_slab(self, slab: torch.Tensor, dim: int) -> torch.Tensor:
        """Reduce ``slab`` along ``dim`` with the configured projection."""
        if self.projection_type == 'median':
            # torch.median with ``dim`` returns (values, indices).
            projected, _ = self.projection_fun(slab, dim=dim, keepdim=True)
        elif self.projection_type == 'quantile':
            projected = self.projection_fun(
                slab, q=self.q, dim=dim, keepdim=True)
        else:
            projected = self.projection_fun(slab, dim=dim, keepdim=True)
        return projected

    def projection(self, tensor: torch.Tensor) -> torch.Tensor:
        """Project ``tensor`` slab by slab along ``self.axis_index``.

        ``self.axis_index`` must have been set (``apply_projection`` does
        this). Each slab is reduced to a single voxel along the axis, and the
        results are concatenated along that same axis.

        Args:
            tensor: Image data to project.

        Returns:
            Tensor whose ``axis_index`` dimension has one entry per slab.
        """
        axis_index = self.axis_index
        axis_span = tensor.shape[axis_index]
        slab_thickness = self._get_slab_thickness(axis_span)

        # torch.mean / torch.quantile require floating-point input.
        if self.projection_type in ('mean', 'quantile'):
            tensor = tensor.to(torch.float)

        if self.full_slabs_only:
            # Every start s = k * stride must satisfy s + thickness <= span.
            num_slabs = (axis_span - slab_thickness) // self.stride + 1
        else:
            # Every start s = k * stride must satisfy s < span; trailing
            # slabs are truncated at the end of the axis.
            num_slabs = ceil(axis_span / self.stride)

        slabs = []
        for slab_number in range(num_slabs):
            start_index = slab_number * self.stride
            length = min(slab_thickness, axis_span - start_index)
            # narrow() returns a view on the same device as ``tensor``,
            # avoiding a copy and a CPU-only index tensor.
            slab = tensor.narrow(axis_index, start_index, length)
            slabs.append(self._project_slab(slab, axis_index))

        return torch.cat(slabs, dim=axis_index)
159