1
|
|
|
import warnings |
2
|
|
|
from typing import Union, Tuple, Optional |
3
|
|
|
|
4
|
|
|
import numpy as np |
5
|
|
|
|
6
|
|
|
from .pad import Pad |
7
|
|
|
from .crop import Crop |
8
|
|
|
from .bounds_transform import BoundsTransform |
9
|
|
|
from ...transform import TypeTripletInt, TypeSixBounds |
10
|
|
|
from ....data.subject import Subject |
11
|
|
|
|
12
|
|
|
|
13
|
|
|
class CropOrPad(BoundsTransform):
    """Crop and/or pad an image to a target shape.

    This transform modifies the affine matrix associated to the volume so that
    physical positions of the voxels are maintained.

    Args:
        target_shape: Tuple :math:`(W, H, D)`. If a single value :math:`N` is
            provided, then :math:`W = H = D = N`.
        padding_mode: Same as :attr:`padding_mode` in
            :class:`~torchio.transforms.Pad`.
        mask_name: If ``None``, the centers of the input and output volumes
            will be the same.
            If a string is given, the output volume center will be the center
            of the bounding box of non-zero values in the image named
            :attr:`mask_name`. If no image with that name is found in the
            subject, or all its values are zero, a :class:`RuntimeWarning` is
            emitted and the volume center is used instead.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        ValueError: If :attr:`mask_name` is neither ``None`` nor a string.

    Example:
        >>> import torchio as tio
        >>> subject = tio.Subject(
        ...     chest_ct=tio.ScalarImage('subject_a_ct.nii.gz'),
        ...     heart_mask=tio.LabelMap('subject_a_heart_seg.nii.gz'),
        ... )
        >>> subject.chest_ct.shape
        torch.Size([1, 512, 512, 289])
        >>> transform = tio.CropOrPad(
        ...     (120, 80, 180),
        ...     mask_name='heart_mask',
        ... )
        >>> transformed = transform(subject)
        >>> transformed.chest_ct.shape
        torch.Size([1, 120, 80, 180])
    """
    def __init__(
            self,
            target_shape: Union[int, TypeTripletInt],
            padding_mode: Union[str, float] = 0,
            mask_name: Optional[str] = None,
            **kwargs
            ):
        super().__init__(target_shape, **kwargs)
        self.padding_mode = padding_mode
        if mask_name is not None and not isinstance(mask_name, str):
            message = (
                'If mask_name is not None, it must be a string,'
                f' not {type(mask_name)}'
            )
            raise ValueError(message)
        self.mask_name = mask_name
        # Pick the strategy once so that apply_transform can call it directly.
        # mask_name has already been validated above, so no second check is
        # needed here.
        if self.mask_name is None:
            self.compute_crop_or_pad = self._compute_center_crop_or_pad
        else:
            self.compute_crop_or_pad = self._compute_mask_center_crop_or_pad

    @staticmethod
    def _bbox_mask(mask_volume: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return the bounding box of the non-zero values of a 3D mask.

        Taken from `this SO question <https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array>`_.

        Args:
            mask_volume: 3D NumPy array. It must contain at least one non-zero
                value, otherwise indexing the empty result of
                :func:`numpy.where` raises an :class:`IndexError`.

        Returns:
            Tuple ``(bb_min, bb_max)`` of two 3-element arrays holding, for
            each axis, the first index (inclusive) and the last index plus one
            (exclusive) of the non-zero values.
        """  # noqa: E501
        i_any = np.any(mask_volume, axis=(1, 2))
        j_any = np.any(mask_volume, axis=(0, 2))
        k_any = np.any(mask_volume, axis=(0, 1))
        i_min, i_max = np.where(i_any)[0][[0, -1]]
        j_min, j_max = np.where(j_any)[0][[0, -1]]
        k_min, k_max = np.where(k_any)[0][[0, -1]]
        bb_min = np.array([i_min, j_min, k_min])
        bb_max = np.array([i_max, j_max, k_max]) + 1
        return bb_min, bb_max

    @staticmethod
    def _get_six_bounds_parameters(
            parameters: np.ndarray,
            ) -> TypeSixBounds:
        r"""Compute bounds parameters for ITK filters.

        Args:
            parameters: Array :math:`(w, h, d)` with the number of voxels to
                be cropped or padded.

        Returns:
            Tuple :math:`(w_{ini}, w_{fin}, h_{ini}, h_{fin}, d_{ini}, d_{fin})`,
            where :math:`n_{ini} = \left \lceil \frac{n}{2} \right \rceil` and
            :math:`n_{fin} = \left \lfloor \frac{n}{2} \right \rfloor`.

        Example:
            >>> p = np.array((4, 0, 7))
            >>> CropOrPad._get_six_bounds_parameters(p)
            (2, 2, 0, 0, 4, 3)
        """  # noqa: E501
        # Division requires an array; plain tuples do not support "/"
        halves = np.asarray(parameters) / 2
        result = []
        for number in halves:
            # The odd voxel, if any, goes to the beginning of the axis
            ini, fin = int(np.ceil(number)), int(np.floor(number))
            result.extend([ini, fin])
        return tuple(result)

    @property
    def target_shape(self):
        """Target spatial shape :math:`(W, H, D)` of the output volume."""
        # bounds_parameters is (W, W, H, H, D, D); take every second element
        return self.bounds_parameters[::2]

    def _compute_cropping_padding_from_shapes(
            self,
            source_shape: TypeTripletInt,
            target_shape: TypeTripletInt,
            ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        """Compute centered padding and cropping between two spatial shapes.

        Args:
            source_shape: Current spatial shape :math:`(W, H, D)`.
            target_shape: Desired spatial shape :math:`(W, H, D)`.

        Returns:
            Tuple ``(padding_params, cropping_params)`` in the six-bounds
            format of :meth:`_get_six_bounds_parameters`. Each element is
            ``None`` when no padding (respectively cropping) is needed.
        """
        # Convert both operands so plain tuples are accepted as well
        diff_shape = np.asarray(target_shape) - np.asarray(source_shape)

        # Negative differences mean the source is too large: crop
        cropping = -np.minimum(diff_shape, 0)
        if cropping.any():
            cropping_params = self._get_six_bounds_parameters(cropping)
        else:
            cropping_params = None

        # Positive differences mean the source is too small: pad
        padding = np.maximum(diff_shape, 0)
        if padding.any():
            padding_params = self._get_six_bounds_parameters(padding)
        else:
            padding_params = None

        return padding_params, cropping_params

    def _compute_center_crop_or_pad(
            self,
            subject: Subject,
            ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        """Compute padding and cropping that keep the volume center fixed.

        Returns:
            Tuple ``(padding_params, cropping_params)``; each may be ``None``.
        """
        source_shape = subject.spatial_shape
        # The parent class turns the 3-element shape tuple (w, h, d)
        # into a 6-element bounds tuple (w, w, h, h, d, d)
        target_shape = np.array(self.bounds_parameters[::2])
        return self._compute_cropping_padding_from_shapes(
            source_shape, target_shape)

    def _compute_mask_center_crop_or_pad(
            self,
            subject: Subject,
            ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        """Compute padding and cropping centered on a mask's bounding box.

        Falls back to :meth:`_compute_center_crop_or_pad` (with a
        :class:`RuntimeWarning`) if the mask image is missing from the
        subject or contains only zeros.

        Returns:
            Tuple ``(padding_params, cropping_params)``; each may be ``None``.
        """
        if self.mask_name not in subject:
            message = (
                f'Mask name "{self.mask_name}"'
                f' not found in subject keys "{tuple(subject.keys())}".'
                ' Using volume center instead'
            )
            warnings.warn(message, RuntimeWarning)
            return self._compute_center_crop_or_pad(subject=subject)

        mask = subject[self.mask_name].numpy()

        if not np.any(mask):
            message = (
                f'All values found in the mask "{self.mask_name}"'
                ' are zero. Using volume center instead'
            )
            warnings.warn(message, RuntimeWarning)
            return self._compute_center_crop_or_pad(subject=subject)

        # Collapse the channel dimension so the bounding box covers the
        # non-zero values of every channel. This is consistent with the
        # emptiness check above: using only the first channel raised an
        # IndexError when that channel was empty but others were not.
        # For single-channel masks the result is identical to using mask[0].
        mask_volume = np.any(mask, axis=0)

        # Let's assume that the center of first voxel is at coordinate 0.5
        # (which is typically not the case)
        subject_shape = subject.spatial_shape
        bb_min, bb_max = self._bbox_mask(mask_volume)
        center_mask = np.mean((bb_min, bb_max), axis=0)
        padding = []
        cropping = []
        target_shape = np.array(self.target_shape)

        for target_dim, center_dim, subject_dim in zip(
                target_shape, center_mask, subject_shape):
            center_on_index = not (center_dim % 1)
            target_even = not (target_dim % 2)

            # Approximation when the center cannot be computed exactly
            # The output will be off by half a voxel, but this is just an
            # implementation detail
            if target_even ^ center_on_index:
                center_dim -= 0.5

            # After the adjustment above, begin and end are whole numbers
            begin = center_dim - target_dim / 2
            end = center_dim + target_dim / 2

            # A window starting before the volume needs padding at the
            # beginning; otherwise the leading voxels are cropped
            if begin >= 0:
                crop_ini, pad_ini = begin, 0
            else:
                crop_ini, pad_ini = 0, -begin

            # Same reasoning for the end of the axis
            if end <= subject_dim:
                crop_fin, pad_fin = subject_dim - end, 0
            else:
                crop_fin, pad_fin = 0, end - subject_dim

            padding.extend([pad_ini, pad_fin])
            cropping.extend([crop_ini, crop_fin])

        # Conversion for SimpleITK compatibility
        padding = np.asarray(padding, dtype=int)
        cropping = np.asarray(cropping, dtype=int)
        padding_params = tuple(padding.tolist()) if padding.any() else None
        cropping_params = tuple(cropping.tolist()) if cropping.any() else None
        return padding_params, cropping_params

    def apply_transform(self, subject: Subject) -> Subject:
        """Pad and/or crop every image in the subject to the target shape."""
        padding_params, cropping_params = self.compute_crop_or_pad(subject)
        padding_kwargs = {'padding_mode': self.padding_mode}
        if padding_params is not None:
            subject = Pad(padding_params, **padding_kwargs)(subject)
        if cropping_params is not None:
            subject = Crop(cropping_params)(subject)
        # Internal consistency check (skipped when Python runs with -O)
        actual, target = subject.spatial_shape, self.target_shape
        assert actual == target, (actual, target)
        return subject
240
|
|
|
|