|
1
|
|
|
import warnings |
|
2
|
|
|
from typing import Union, Tuple, Optional |
|
3
|
|
|
import numpy as np |
|
4
|
|
|
from deprecated import deprecated |
|
5
|
|
|
from .pad import Pad |
|
6
|
|
|
from .crop import Crop |
|
7
|
|
|
from .bounds_transform import BoundsTransform, TypeShape, TypeSixBounds |
|
8
|
|
|
from ....torchio import DATA |
|
9
|
|
|
from ....data.subject import Subject |
|
10
|
|
|
from ....utils import round_up |
|
11
|
|
|
|
|
12
|
|
|
|
|
13
|
|
|
class CropOrPad(BoundsTransform):
    """Crop and/or pad an image to a target shape.

    This transform modifies the affine matrix associated to the volume so that
    physical positions of the voxels are maintained.

    Args:
        target_shape: Tuple :math:`(D, H, W)`. If a single value :math:`N` is
            provided, then :math:`D = H = W = N`.
        padding_mode: See :py:class:`~torchio.transforms.Pad`.
        padding_fill: Same as :attr:`fill` in
            :py:class:`~torchio.transforms.Pad`.
        mask_name: If ``None``, the centers of the input and output volumes
            will be the same.
            If a string is given, the output volume center will be the center
            of the bounding box of non-zero values in the image named
            :py:attr:`mask_name`.
        p: Probability that this transform will be applied.

    Raises:
        ValueError: If ``mask_name`` is neither ``None`` nor a string.

    Example:
        >>> import torchio
        >>> from torchio.transforms import CropOrPad
        >>> subject = torchio.Subject(
        ...     torchio.Image('chest_ct', 'subject_a_ct.nii.gz', torchio.INTENSITY),
        ...     torchio.Image('heart_mask', 'subject_a_heart_seg.nii.gz', torchio.LABEL),
        ... )
        >>> sample = torchio.ImagesDataset([subject])[0]
        >>> sample['chest_ct'][torchio.DATA].shape
        torch.Size([1, 512, 512, 289])
        >>> transform = CropOrPad(
        ...     (120, 80, 180),
        ...     mask_name='heart_mask',
        ... )
        >>> transformed = transform(sample)
        >>> transformed['chest_ct'][torchio.DATA].shape
        torch.Size([1, 120, 80, 180])
    """
    def __init__(
            self,
            target_shape: Union[int, TypeShape],
            padding_mode: str = 'constant',
            padding_fill: Optional[float] = None,
            mask_name: Optional[str] = None,
            p: float = 1,
            ):
        super().__init__(target_shape, p=p)
        self.padding_mode = padding_mode
        self.padding_fill = padding_fill
        if mask_name is not None and not isinstance(mask_name, str):
            message = (
                'If mask_name is not None, it must be a string,'
                f' not {type(mask_name)}'
            )
            raise ValueError(message)
        self.mask_name = mask_name
        # Pick the strategy once, here, so that apply_transform does not need
        # to branch on mask_name every time it is called
        if self.mask_name is None:
            self.compute_crop_or_pad = self._compute_center_crop_or_pad
        else:
            self.compute_crop_or_pad = self._compute_mask_center_crop_or_pad

    @staticmethod
    def _bbox_mask(
            mask_volume: np.ndarray,
            ) -> Tuple[np.ndarray, np.ndarray]:
        """Return 6 coordinates of a 3D bounding box from a given mask.

        Taken from `this SO question <https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array>`_.

        Args:
            mask_volume: 3D NumPy array. It must contain at least one non-zero
                value, otherwise indexing the empty result of ``np.where``
                raises :class:`IndexError`.

        Returns:
            Tuple ``(bb_min, bb_max)`` of 3-element arrays with the first and
            last (inclusive) index holding a non-zero value along each axis.
        """
        # Collapse the two other axes to know which slices contain foreground
        i_any = np.any(mask_volume, axis=(1, 2))
        j_any = np.any(mask_volume, axis=(0, 2))
        k_any = np.any(mask_volume, axis=(0, 1))
        # First and last index of the slices that contain foreground
        i_min, i_max = np.where(i_any)[0][[0, -1]]
        j_min, j_max = np.where(j_any)[0][[0, -1]]
        k_min, k_max = np.where(k_any)[0][[0, -1]]
        bb_min = np.array([i_min, j_min, k_min])
        bb_max = np.array([i_max, j_max, k_max])
        return bb_min, bb_max

    @staticmethod
    def _get_sample_shape(sample: Subject) -> TypeShape:
        """Return the spatial shape of the first image in the sample.

        Args:
            sample: Subject whose images are inspected.

        Returns:
            Shape of the first image without its leading channels dimension.

        Raises:
            ValueError: If the sample does not contain any image.
        """
        # Presumably this makes the first image representative of all of them
        sample.check_consistent_shape()
        for image_dict in sample.get_images(intensity_only=False):
            return image_dict[DATA].shape[1:]  # remove channels dimension
        # Without this, an empty sample would surface as an obscure
        # UnboundLocalError on the return statement
        raise ValueError('The sample does not contain any images')

    @staticmethod
    def _get_six_bounds_parameters(
            parameters: np.ndarray,
            ) -> TypeSixBounds:
        r"""Compute bounds parameters for ITK filters.

        Args:
            parameters: Tuple :math:`(d, h, w)` with the number of voxels to be
                cropped or padded.

        Returns:
            Tuple :math:`(d_{ini}, d_{fin}, h_{ini}, h_{fin}, w_{ini}, w_{fin})`,
            where :math:`n_{ini} = \left \lceil \frac{n}{2} \right \rceil` and
            :math:`n_{fin} = \left \lfloor \frac{n}{2} \right \rfloor`.

        Example:
            >>> p = np.array((4, 0, 7))
            >>> _get_six_bounds_parameters(p)
            (2, 2, 0, 0, 4, 3)

        """
        # asarray lets plain tuples and lists be halved as well
        halves = np.asarray(parameters) / 2
        result = []
        for half in halves:
            # An odd amount is split so that the extra voxel goes first
            result.extend([int(np.ceil(half)), int(np.floor(half))])
        return tuple(result)

    def _compute_cropping_padding_from_shapes(
            self,
            source_shape: TypeShape,
            target_shape: TypeShape,
            ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        """Compute centered padding and cropping from two spatial shapes.

        Args:
            source_shape: Current spatial shape :math:`(D, H, W)`.
            target_shape: Desired spatial shape :math:`(D, H, W)`.

        Returns:
            Tuple ``(padding_params, cropping_params)``. Each element is a
            6-element bounds tuple, or ``None`` if that operation is not
            needed along any dimension.
        """
        # Convert explicitly so that the subtraction is element-wise even if
        # both shapes are plain tuples
        diff_shape = np.asarray(target_shape) - np.asarray(source_shape)

        # Negative differences mean the source is too large: crop
        cropping = -np.minimum(diff_shape, 0)
        if cropping.any():
            cropping_params = self._get_six_bounds_parameters(cropping)
        else:
            cropping_params = None

        # Positive differences mean the source is too small: pad
        padding = np.maximum(diff_shape, 0)
        if padding.any():
            padding_params = self._get_six_bounds_parameters(padding)
        else:
            padding_params = None

        return padding_params, cropping_params

    def _compute_center_crop_or_pad(
            self,
            sample: Subject,
            ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        """Compute padding and cropping around the center of the volume.

        Args:
            sample: Subject to be transformed.

        Returns:
            Tuple ``(padding_params, cropping_params)``, see
            :py:meth:`_compute_cropping_padding_from_shapes`.
        """
        source_shape = self._get_sample_shape(sample)
        # The parent class turns the 3-element shape tuple (d, h, w)
        # into a 6-element bounds tuple (d, d, h, h, w, w)
        target_shape = np.array(self.bounds_parameters[::2])
        return self._compute_cropping_padding_from_shapes(
            source_shape, target_shape)

    def _compute_mask_center_crop_or_pad(
            self,
            sample: Subject,
            ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        """Compute padding and cropping around the center of a mask.

        The output volume is centered on the bounding box of the non-zero
        values of the image named :py:attr:`mask_name`. If that image is
        missing or contains only zeros, a warning is issued and the center of
        the volume is used instead.

        Args:
            sample: Subject to be transformed.

        Returns:
            Tuple ``(padding_params, cropping_params)``. Each element is a
            6-element bounds tuple, or ``None`` if that operation is not
            needed along any dimension.
        """
        if self.mask_name not in sample:
            message = (
                f'Mask name "{self.mask_name}"'
                f' not found in sample keys "{tuple(sample.keys())}".'
                ' Using volume center instead'
            )
            warnings.warn(message)
            return self._compute_center_crop_or_pad(sample=sample)

        mask = sample[self.mask_name][DATA].numpy()

        if not np.any(mask):
            message = (
                f'All values found in the mask "{self.mask_name}"'
                ' are zero. Using volume center instead'
            )
            warnings.warn(message)
            return self._compute_center_crop_or_pad(sample=sample)

        # Spatial shape of the sample, channels dimension already removed
        sample_shape = self._get_sample_shape(sample)
        # Bounding box of the first channel of the mask
        bb_min, bb_max = self._bbox_mask(mask[0])
        # Coordinates of the center of the bounding box
        center_mask = (bb_max - bb_min) / 2 + bb_min
        # (ini, fin) pairs of voxels to pad, one pair per dimension
        padding = []
        # (ini, fin) pairs of voxels to crop, one pair per dimension
        cropping = []
        for dim, center_dimension in enumerate(center_mask):
            center_dim = round_up(center_dimension)
            # Bounds parameters are (d, d, h, h, w, w), so both entries of a
            # dimension hold the same target size
            target_dim = int(self.bounds_parameters[2 * dim])
            source_dim = int(sample_shape[dim])
            # Round only the first index and derive the last one from it.
            # Rounding both ends independently (the built-in round() rounds
            # halves to even) made the output size depend on parity and be
            # off by one voxel for odd target sizes.
            # assumes round_up returns an integer-valued number — TODO confirm
            begin = int(round_up(center_dim - target_dim / 2))
            end = begin + target_dim
            # A window starting before 0 or ending after the volume needs
            # padding; otherwise the voxels outside the window are cropped
            padding.extend([max(-begin, 0), max(end - source_dim, 0)])
            cropping.extend([max(begin, 0), max(source_dim - end, 0)])
        # Conversion for SimpleITK compatibility
        padding = np.asarray(padding, dtype=int)
        cropping = np.asarray(cropping, dtype=int)
        padding_params = tuple(padding.tolist()) if padding.any() else None
        cropping_params = tuple(cropping.tolist()) if cropping.any() else None
        return padding_params, cropping_params

    def apply_transform(self, sample: Subject) -> dict:
        """Pad and then crop the sample so that it has the target shape.

        Args:
            sample: Subject to be transformed.

        Returns:
            The sample after the padding and/or cropping that were needed.
        """
        padding_params, cropping_params = self.compute_crop_or_pad(sample)
        padding_kwargs = dict(
            padding_mode=self.padding_mode, fill=self.padding_fill)
        # Padding goes first: cropping bounds are relative to the padded volume
        if padding_params is not None:
            sample = Pad(padding_params, **padding_kwargs)(sample)
        if cropping_params is not None:
            sample = Crop(cropping_params)(sample)
        return sample
|
238
|
|
|
|
|
239
|
|
|
|
|
240
|
|
|
@deprecated('CenterCropOrPad is deprecated. Use CropOrPad instead.')
class CenterCropOrPad(CropOrPad):
    """Deprecated alias of :py:class:`CropOrPad`.

    It adds no behavior of its own; use :py:class:`CropOrPad` instead.
    """
|
243
|
|
|
|