1
|
|
|
from numbers import Real |
2
|
|
|
from typing import Union, Optional, Callable |
3
|
|
|
|
4
|
|
|
import torch |
5
|
|
|
from torchio.data.image import ScalarImage |
6
|
|
|
from ....data.subject import Subject |
7
|
|
|
from ...intensity_transform import IntensityTransform |
8
|
|
|
|
9
|
|
|
|
10
|
|
|
class SlabProjection(IntensityTransform):
    """Project intensities along a given axis, possibly with sliding slabs.

    Args:
        axis: Index for the spatial dimension to project across.
            See :class:`~.torchio.RandomFlip` for information on the accepted
            options.
        slab_thickness: Thickness of slab projections (number of voxels along
            the ``axis`` dimension to project across).
            If ``None``, the projection will be done across the entire span of
            the ``axis`` dimension (i.e. ``axis`` dimension will be reduced to
            1). Values larger than the span of an image are clamped to that
            span for that image only.
        stride: Number of voxels to stride along the ``axis`` dimension between
            slab projections. Must be at least 1. Default is 1.
        projection_type: Type of intensity projection. Possible inputs are
            ``'max'``, ``'min'``, ``'mean'``, ``'median'``, or
            ``'percentile'``. If ``'percentile'`` is used, the ``percentile``
            argument must also be supplied.
        percentile: Percentile, in the range [0, 100], to use for intensity
            projections. This argument is required if ``projection_type`` is
            ``'percentile'`` and is silently ignored otherwise.
        full_slabs_only: If ``True``, projections will be done only for slabs
            that are ``slab_thickness`` thick. If ``False``, some slabs may not
            be ``slab_thickness`` thick depending on the size of the image,
            slab thickness, and stride.
        **kwargs: Forwarded unchanged to the parent class constructor.

    Example:
        >>> import torchio as tio
        >>> ct = tio.datasets.Slicer('CTChest').CT_chest
        >>> axial_mip = tio.SlabProjection("S", slab_thickness=20)
        >>> ct_t = axial_mip(ct)
        >>> ct_t.plot()

    .. plot::

        import torchio as tio
        sub = tio.datasets.Slicer('CTChest')
        ct = sub.CT_chest
        axial_mip = tio.SlabProjection("S", slab_thickness=20)
        ct_mip = axial_mip(ct)
        sub.add_image(ct_mip, 'CT_MIP')
        sub = tio.Clamp(-1000, 1000)(sub)
        sub.plot()

    """
    def __init__(
            self,
            axis: Union[int, str],
            slab_thickness: Optional[int] = None,
            stride: int = 1,
            projection_type: str = 'max',
            percentile: Optional[float] = None,
            full_slabs_only: bool = True,
            **kwargs
            ):
        super().__init__(**kwargs)
        self.args_names = (
            'axis', 'slab_thickness', 'stride',
            'projection_type', 'percentile', 'full_slabs_only'
        )
        self.axis = self.parse_axes(axis)[0]
        # A non-positive stride can never advance the slab window, and a
        # non-positive thickness selects no voxels: reject both up front.
        if slab_thickness is not None and slab_thickness < 1:
            message = (
                'slab_thickness must be a positive integer or None,'
                f' but "{slab_thickness}" was passed'
            )
            raise ValueError(message)
        if stride < 1:
            message = (
                f'stride must be a positive integer, but "{stride}" was passed'
            )
            raise ValueError(message)
        self.slab_thickness = slab_thickness
        self.stride = stride
        self.projection_fun = self.get_projection_function(projection_type)
        self.projection_type = projection_type
        # NOTE(review): for projection_type='percentile' the stored value is
        # the fraction (percentile / 100), while 'percentile' is also listed
        # in args_names. If args_names is used to re-create the transform,
        # the value would be rescaled a second time -- confirm.
        self.percentile = self.validate_percentile(percentile, projection_type)
        self.full_slabs_only = full_slabs_only

    @staticmethod
    def validate_percentile(percentile, projection_type):
        """Check ``percentile`` and convert it to a fraction in [0, 1].

        Args:
            percentile: Value supplied by the user, expected in [0, 100].
            projection_type: Name of the requested projection.

        Returns:
            ``percentile`` unchanged if ``projection_type`` is not
            ``'percentile'``; otherwise ``percentile / 100``.

        Raises:
            TypeError: If a percentile projection is requested and
                ``percentile`` is not a real number.
            ValueError: If a percentile projection is requested and
                ``percentile`` is outside [0, 100].
        """
        if projection_type != 'percentile':
            return percentile
        message = (
            "For projection_type='percentile', `percentile` must be a scalar"
            f' value in the range [0, 100], but "{percentile}" was passed'
        )
        if not isinstance(percentile, Real):
            raise TypeError(message)
        if not 0 <= percentile <= 100:
            raise ValueError(message)
        return percentile / 100

    @staticmethod
    def get_projection_function(projection_type: str) -> Callable:
        """Return the ``torch`` reduction matching ``projection_type``.

        Raises:
            ValueError: If ``projection_type`` is not a supported name.
        """
        arg_to_function = {
            'max': 'amax',
            'min': 'amin',
            'mean': 'mean',
            'median': 'median',
            'percentile': 'quantile',
        }
        try:
            function_name = arg_to_function[projection_type]
        except KeyError:
            message = (
                f'The projection type must be in {arg_to_function.keys()}, '
                f' but {projection_type} was passed'
            )
            raise ValueError(message)
        projection_function = getattr(torch, function_name)
        return projection_function

    def get_num_slabs(
            self,
            axis_span: int,
            slab_thickness: Optional[int] = None,
            ) -> int:
        """Compute how many slabs fit along an axis of length ``axis_span``.

        Args:
            axis_span: Number of voxels along the projection axis.
            slab_thickness: Thickness to use for this computation. If
                ``None``, ``self.slab_thickness`` is used, and if that is
                also ``None`` the whole span is used.

        Returns:
            Number of slab start positions ``0, stride, 2 * stride, ...``
            to project. With ``full_slabs_only`` only positions where a
            complete slab fits are counted; otherwise every start position
            inside the axis is counted.
        """
        if slab_thickness is None:
            slab_thickness = self.slab_thickness
        if slab_thickness is None:
            slab_thickness = axis_span
        if self.full_slabs_only:
            if slab_thickness > axis_span:
                return 0
            return int((axis_span - slab_thickness) // self.stride + 1)
        # Ceiling division written with integer arithmetic only.
        return int(-(-axis_span // self.stride))

    def apply_transform(self, subject: Subject) -> Subject:
        """Project every image returned by ``get_images`` and return subject."""
        spatial_axis_index = self.ensure_axes_indices(subject, [self.axis])[0]
        tensor_axis_index = spatial_axis_index + 1  # channels is 0
        for image in self.get_images(subject):
            self.apply_projection(image, tensor_axis_index)
        return subject

    def apply_projection(self, image: ScalarImage, axis_index: int) -> None:
        """Replace the data of ``image`` with its projection, in place.

        The effective slab thickness is resolved per image and kept local,
        so the configured ``self.slab_thickness`` is never modified and the
        transform can be reused on images of different sizes.
        """
        axis_span = image.shape[axis_index]
        slab_thickness = self.slab_thickness
        if slab_thickness is None or slab_thickness > axis_span:
            slab_thickness = axis_span
        projected = self.projection(
            image.data, axis_index, axis_span, slab_thickness)
        image.set_data(projected)

    def projection(
            self,
            tensor: torch.Tensor,
            axis_index: int,
            axis_span: int,
            slab_thickness: Optional[int] = None,
            ) -> torch.Tensor:
        """Project ``tensor`` slab by slab along ``axis_index``.

        Args:
            tensor: Data to project.
            axis_index: Dimension of ``tensor`` to project across.
            axis_span: Size of ``tensor`` along ``axis_index``.
            slab_thickness: Thickness of each slab. If ``None``,
                ``self.slab_thickness`` is used. A missing or too large
                thickness is clamped to ``axis_span``.

        Returns:
            Tensor with one entry along ``axis_index`` per projected slab.
        """
        if slab_thickness is None:
            slab_thickness = self.slab_thickness
        if slab_thickness is None or slab_thickness > axis_span:
            slab_thickness = axis_span

        # The mean and quantile reductions are applied to floating point data.
        if self.projection_type in ('mean', 'percentile'):
            tensor = tensor.to(torch.float)

        num_slabs = self.get_num_slabs(axis_span, slab_thickness)

        slabs = []
        for slab_number in range(num_slabs):
            start_index = slab_number * self.stride
            # Trailing slabs are truncated at the end of the axis when
            # partial slabs are allowed.
            end_index = min(start_index + slab_thickness, axis_span)
            # narrow returns a view: no index tensor and no copy per slab.
            slab = tensor.narrow(
                axis_index, start_index, end_index - start_index)
            if self.projection_type == 'median':
                # torch.median with ``dim`` returns (values, indices).
                projected, _indices = self.projection_fun(
                    slab, dim=axis_index, keepdim=True)
            elif self.projection_type == 'percentile':
                projected = self.projection_fun(
                    slab, q=self.percentile, dim=axis_index,
                    keepdim=True)
            else:
                projected = self.projection_fun(
                    slab, dim=axis_index, keepdim=True)
            slabs.append(projected)

        return torch.cat(slabs, dim=axis_index)
175
|
|
|
|