Passed
Pull Request — master (#625)
by
unknown
01:17
created

Projection.projection()   B

Complexity

Conditions 8

Size

Total Lines 36
Code Lines 30

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 30
dl 0
loc 36
rs 7.2933
c 0
b 0
f 0
cc 8
nop 2
1
import torch
2
from torchio.data.image import ScalarImage
3
from ....data.subject import Subject
4
from ...intensity_transform import IntensityTransform
5
from typing import Optional
6
from math import ceil
7
8
9
class Projection(IntensityTransform):
    """Project intensities along a given axis, possibly with sliding slabs.

    Args:
        axis: Possible inputs are ``'Left'``, ``'Right'``, ``'Anterior'``,
            ``'Posterior'``, ``'Inferior'``, ``'Superior'``. Lower-case
            versions and first letters are also valid, as only the first
            letter will be used.
        slab_thickness: Thickness of slab projections. In other words, the
            number of voxels in the ``axis`` dimension to project across.
            If ``None``, the projection will be done across the entire span of
            the ``axis`` dimension (i.e. ``axis`` dimension will be reduced to
            1). Values larger than the span of the ``axis`` dimension are
            clipped to that span, independently for each image. Must be a
            positive integer if given.
        stride: Number of voxels to stride along the ``axis`` dimension between
            slab projections. Must be a positive integer.
        projection_type: Type of intensity projection. Possible inputs are
            ``'max'`` (the default), ``'min'``, ``'mean'``, ``'median'``, or
            ``'quantile'``. If ``'quantile'`` is used, ``q`` must also be
            supplied.
        q: Quantile to use for intensity projections. This argument is required
            if ``projection_type`` is ``'quantile'`` and is silently ignored
            otherwise.
        full_slabs_only: Boolean. Should projections be done only for slabs
            that are ``slab_thickness`` thick? Default is ``True``.
            If ``False``, some slabs may not be ``slab_thickness`` thick
            depending on the size of the image, slab thickness, and stride.

    Raises:
        ValueError: If ``slab_thickness`` is given and smaller than 1, if
            ``stride`` is ``None`` or smaller than 1, if ``projection_type``
            is not one of the supported values, or if ``projection_type`` is
            ``'quantile'`` and ``q`` is missing or outside ``[0, 1]``.

    Example:
        >>> import torchio as tio
        >>> sub = tio.datasets.Colin27()
        >>> axial_mips = tio.Projection("S", slab_thickness=20)
        >>> sub_t = axial_mips(sub)
        >>> sub_t.t1.plot()
    """
    def __init__(
            self,
            axis: str,
            slab_thickness: Optional[int] = None,
            stride: Optional[int] = 1,
            projection_type: Optional[str] = 'max',
            q: Optional[float] = None,
            full_slabs_only: Optional[bool] = True,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.args_names = (
            'axis', 'slab_thickness', 'stride',
            'projection_type', 'q', 'full_slabs_only'
            )
        self.axis = axis
        self.slab_thickness = slab_thickness
        self.stride = stride
        self.projection_type = projection_type
        self.q = q
        self.full_slabs_only = full_slabs_only
        self._validate_slab_parameters()
        self.projection_fun = self.get_projection_function()

    def _validate_slab_parameters(self) -> None:
        """Reject slab settings that would give empty slabs or never advance."""
        if self.slab_thickness is not None and self.slab_thickness < 1:
            message = (
                '`slab_thickness` must be a positive integer or None,'
                f' not {self.slab_thickness}.'
                )
            raise ValueError(message)
        # A non-positive stride would never move the slab window forward
        if self.stride is None or self.stride < 1:
            message = f'`stride` must be a positive integer, not {self.stride}.'
            raise ValueError(message)

    def get_projection_function(self):
        """Return the torch reduction matching ``self.projection_type``.

        Raises:
            ValueError: If ``projection_type`` is unsupported, or if it is
                ``'quantile'`` and ``q`` is invalid.
        """
        if self.projection_type == 'max':
            projection_fun = torch.amax
        elif self.projection_type == 'min':
            projection_fun = torch.amin
        elif self.projection_type == 'mean':
            projection_fun = torch.mean
        elif self.projection_type == 'median':
            projection_fun = torch.median
        elif self.projection_type == 'quantile':
            projection_fun = torch.quantile
            self.validate_quantile()
        else:
            message = (
                '`projection_type` must be one of "max", "min", "mean",'
                ' "median", or "quantile".'
                )
            raise ValueError(message)
        return projection_fun

    def validate_quantile(self):
        """Raise ``ValueError`` unless ``self.q`` is a number in ``[0, 1]``."""
        if self.q is None or not 0 <= self.q <= 1:
            message = (
                'For `projection_type="quantile"`, `q` must be a scalar value'
                f' in the range [0, 1], not {self.q}.'
                )
            raise ValueError(message)

    def apply_transform(self, subject: Subject) -> Subject:
        for image in self.get_images(subject):
            self.apply_projection(image)
        return subject

    def apply_projection(self, image: ScalarImage) -> None:
        """Replace the data of ``image`` with its slab projections.

        ``self.axis_index`` and ``self.axis_span`` are (re)computed for every
        image. ``self.slab_thickness`` is deliberately left untouched, so
        images of different sizes are each handled with the user's original
        setting.
        """
        self.axis_index = image.axis_name_to_index(self.axis)
        self.axis_span = image.shape[self.axis_index]
        image.set_data(self.projection(image.data))

    def _get_slab_thickness(self, axis_span: int) -> int:
        """Return the slab thickness to use for an axis of ``axis_span``."""
        if self.slab_thickness is None:
            return axis_span
        return min(self.slab_thickness, axis_span)

    def _get_slab_starts(self, axis_span: int, slab_thickness: int) -> range:
        """Return the start index of every slab along the projected axis."""
        if self.full_slabs_only:
            # Last start from which a full-thickness slab still fits
            last_start = axis_span - slab_thickness
        else:
            # Any start inside the axis; trailing slabs may be truncated
            last_start = axis_span - 1
        return range(0, last_start + 1, self.stride)

    def _project_slab(
            self,
            slab: torch.Tensor,
            axis_index: int,
            ) -> torch.Tensor:
        """Reduce ``slab`` along ``axis_index``, keeping that dimension."""
        if self.projection_type == 'median':
            # torch.median with ``dim`` returns (values, indices)
            projected, _ = self.projection_fun(
                slab, dim=axis_index, keepdim=True)
        elif self.projection_type == 'quantile':
            projected = self.projection_fun(
                slab, q=self.q, dim=axis_index, keepdim=True)
        else:
            projected = self.projection_fun(
                slab, dim=axis_index, keepdim=True)
        return projected

    def projection(self, tensor: torch.Tensor) -> torch.Tensor:
        """Project ``tensor`` slab by slab along ``self.axis_index``.

        ``self.axis_index`` must already be set (``apply_projection`` does
        this). The span of the axis is read from ``tensor`` itself.

        Returns:
            Tensor whose ``self.axis_index`` dimension has one entry per slab.
            All other dimensions are unchanged.
        """
        if self.projection_type in ('mean', 'quantile'):
            # torch.mean and torch.quantile require floating-point input
            tensor = tensor.to(torch.float)

        axis_index = self.axis_index
        axis_span = tensor.shape[axis_index]
        slab_thickness = self._get_slab_thickness(axis_span)

        slabs = []
        for start_index in self._get_slab_starts(axis_span, slab_thickness):
            # Clip slabs that would run past the end of the axis
            length = min(slab_thickness, axis_span - start_index)
            # narrow() returns a view on the tensor's own device, unlike
            # index_select() with a freshly built (CPU) index tensor
            slab = tensor.narrow(axis_index, start_index, length)
            slabs.append(self._project_slab(slab, axis_index))

        return torch.cat(slabs, dim=axis_index)
148