1
|
|
|
from numbers import Real |
2
|
|
|
from typing import Union, Optional, Callable |
3
|
|
|
|
4
|
|
|
import torch |
5
|
|
|
from torchio.data.image import ScalarImage |
6
|
|
|
from ....data.subject import Subject |
7
|
|
|
from ...intensity_transform import IntensityTransform |
8
|
|
|
|
9
|
|
|
|
10
|
|
|
class SlabProjection(IntensityTransform):
    """Project intensities along a given axis, possibly with sliding slabs.

    Args:
        axis: Index for the spatial dimension to project across.
            See :class:`~.torchio.RandomFlip` for information on the accepted
            types.
        slab_thickness: Thickness of slab projections. In other words, the
            number of voxels in the ``axis`` dimension to project across.
            If ``None``, the projection will be done across the entire span of
            the ``axis`` dimension (i.e. ``axis`` dimension will be reduced to
            1). Values larger than the span of an image are clamped to that
            span for that image only.
        stride: Number of voxels to stride along the ``axis`` dimension between
            slab projections. Must be a positive integer. Default is 1.
        projection_type: Type of intensity projection. Possible inputs are
            ``'max'`` (the default), ``'min'``, ``'mean'``, ``'median'``, or
            ``'percentile'``. If ``'percentile'`` is used, the ``percentile``
            argument must also be supplied.
        percentile: Percentile to use for intensity projections, in the range
            [0, 100]. This argument is required if ``projection_type`` is
            ``'percentile'`` and is silently ignored otherwise.
        full_slabs_only: Boolean. Should projections be done only for slabs
            that are ``slab_thickness`` thick? Default is ``True``.
            If ``False``, some slabs may not be ``slab_thickness`` thick
            depending on the size of the image, slab thickness, and stride.

    Example:
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> axial_mip = tio.SlabProjection("S", slab_thickness=20)
        >>> ct_t = axial_mip(ct)
        >>> ct_t.plot()

    .. plot::

        import torchio as tio
        sub = tio.datasets.Slicer('CTChest')
        ct = sub.CT_chest
        axial_mip = tio.SlabProjection("S", slab_thickness=20)
        ct_mip = axial_mip(ct)
        sub.add_image(ct_mip, 'CT_MIP')
        sub = tio.Clamp(-1000, 1000)(sub)
        sub.plot()

    """
    def __init__(
            self,
            axis: Union[int, str],
            slab_thickness: Optional[int] = None,
            stride: int = 1,
            projection_type: str = 'max',
            percentile: Optional[float] = None,
            full_slabs_only: bool = True,
            **kwargs
    ):
        super().__init__(**kwargs)
        self.args_names = (
            'axis', 'slab_thickness', 'stride',
            'projection_type', 'percentile', 'full_slabs_only'
        )
        self.axis = axis
        # Reject non-positive values up front: a stride of 0 or less would
        # otherwise never advance along the axis.
        self.slab_thickness = self._validate_positive_int(
            slab_thickness, 'slab_thickness', allow_none=True)
        self.stride = self._validate_positive_int(stride, 'stride')
        self.projection_fun = self.get_projection_function(projection_type)
        self.projection_type = projection_type
        # NOTE(review): for 'percentile' this stores the fraction in [0, 1],
        # not the value passed in. If `args_names` is used elsewhere to
        # rebuild the transform, the value would be rescaled again - confirm.
        self.percentile = self.validate_percentile(percentile, projection_type)
        self.full_slabs_only = full_slabs_only

    @staticmethod
    def _validate_positive_int(value, name: str, allow_none: bool = False):
        """Return ``value`` as an ``int`` after checking it is a whole number >= 1.

        Args:
            value: Value to check.
            name: Argument name used in the error message.
            allow_none: If ``True``, ``None`` is returned unchanged.

        Raises:
            ValueError: If ``value`` is not a positive whole number.
        """
        if value is None and allow_none:
            return None
        is_valid = (
            isinstance(value, Real)
            and value >= 1
            and int(value) == value
        )
        if not is_valid:
            message = (
                f'`{name}` must be a positive integer,'
                f' but "{value}" was passed'
            )
            raise ValueError(message)
        return int(value)

    @staticmethod
    def validate_percentile(percentile, projection_type):
        """Check ``percentile`` and convert it to a fraction in [0, 1].

        The value is returned untouched when ``projection_type`` is not
        ``'percentile'``, as it is not used in that case.

        Raises:
            TypeError: If ``percentile`` is not a real number.
            ValueError: If ``percentile`` is outside the range [0, 100].
        """
        if projection_type != 'percentile':
            return percentile
        message = (
            "For projection_type='percentile', `percentile` must be a scalar"
            f' value in the range [0, 100], but "{percentile}" was passed'
        )
        if not isinstance(percentile, Real):
            raise TypeError(message)
        if not 0 <= percentile <= 100:
            raise ValueError(message)
        # The quantile function used for the projection expects a fraction
        return percentile / 100

    @staticmethod
    def get_projection_function(projection_type: str) -> Callable:
        """Return the ``torch`` reduction matching ``projection_type``.

        Raises:
            ValueError: If ``projection_type`` is not a supported name.
        """
        arg_to_function = {
            'max': 'amax',
            'min': 'amin',
            'mean': 'mean',
            'median': 'median',
            'percentile': 'quantile',
        }
        if projection_type not in arg_to_function:
            message = (
                f'The projection type must be in {arg_to_function.keys()}, '
                f' but {projection_type} was passed'
            )
            raise ValueError(message)
        return getattr(torch, arg_to_function[projection_type])

    def _effective_slab_thickness(self) -> int:
        """Return the slab thickness to use for the current image.

        The configured ``slab_thickness`` is never modified, so one image's
        size cannot leak into the processing of the next one. Relies on
        ``self.axis_span``, which is set in :meth:`apply_projection`.
        """
        if self.slab_thickness is None:
            return self.axis_span
        return min(self.slab_thickness, self.axis_span)

    def get_num_slabs(self) -> int:
        """Return the number of slabs produced for the current image.

        Relies on ``self.axis_span``, which is set in
        :meth:`apply_projection`.
        """
        if self.full_slabs_only:
            # Count start indices 0, stride, 2 * stride, ... for which a slab
            # of full thickness still fits inside the axis span
            thickness = self._effective_slab_thickness()
            return (self.axis_span - thickness) // self.stride + 1
        # Every start index inside the span yields a (possibly thinner) slab:
        # integer ceil(axis_span / stride)
        return -(-self.axis_span // self.stride)

    def apply_transform(self, subject: Subject) -> Subject:
        """Project every image selected by ``get_images`` in place."""
        for image in self.get_images(subject):
            self.apply_projection(image)
        return subject

    def apply_projection(self, image: ScalarImage) -> None:
        """Replace the data of ``image`` with its slab projections.

        The axis index and axis span of the image are stored on the instance
        because :meth:`get_num_slabs` and :meth:`projection` read them.
        """
        self.axis_index = image.axis_name_to_index(self.axis)
        self.axis_span = image.shape[self.axis_index]
        image.set_data(self.projection(image.data))

    def _project_slab(self, slab: torch.Tensor) -> torch.Tensor:
        """Reduce one slab to a single slice along the projection axis."""
        reduce_kwargs = {'dim': self.axis_index, 'keepdim': True}
        if self.projection_type == 'median':
            # torch.median over a dimension returns (values, indices); for an
            # even number of elements it yields the lower middle value
            values, _ = self.projection_fun(slab, **reduce_kwargs)
            return values
        if self.projection_type == 'percentile':
            return self.projection_fun(
                slab, q=self.percentile, **reduce_kwargs)
        return self.projection_fun(slab, **reduce_kwargs)

    def projection(self, tensor: torch.Tensor) -> torch.Tensor:
        """Compute all slab projections of ``tensor`` along the chosen axis.

        Args:
            tensor: Image data. ``self.axis_index`` and ``self.axis_span``
                must already describe it (see :meth:`apply_projection`).

        Returns:
            Tensor whose size along the projection axis equals the number of
            slabs; the other dimensions are unchanged.
        """
        if self.projection_type in ('mean', 'percentile'):
            # These reductions need a floating point input
            tensor = tensor.to(torch.float)

        thickness = self._effective_slab_thickness()
        slabs = []
        for slab_number in range(self.get_num_slabs()):
            start_index = slab_number * self.stride
            # The last slabs may be thinner when full_slabs_only is False
            length = min(thickness, self.axis_span - start_index)
            # narrow() gives a view on the tensor's own device, so no index
            # tensor has to be built and no copy of the slab is made
            slab = tensor.narrow(self.axis_index, start_index, length)
            slabs.append(self._project_slab(slab))

        return torch.cat(slabs, dim=self.axis_index)
169
|
|
|
|