|
1
|
|
|
from typing import Union, Tuple, Optional |
|
|
|
|
|
|
2
|
|
|
import numpy as np |
|
|
|
|
|
|
3
|
|
|
from deprecated import deprecated |
|
|
|
|
|
|
4
|
|
|
from .pad import Pad |
|
5
|
|
|
from .crop import Crop |
|
6
|
|
|
from .bounds_transform import BoundsTransform |
|
7
|
|
|
from ....torchio import DATA |
|
8
|
|
|
from ....utils import is_image_dict, check_consistent_shape |
|
9
|
|
|
|
|
10
|
|
|
|
|
11
|
|
|
class CropOrPad(BoundsTransform):
    """Crop and/or pad an image to a target shape.

    This transform modifies the affine matrix associated to the volume so that
    physical positions of the voxels are maintained.

    Args:
        target_shape: Tuple :math:`(D, H, W)`. If a single value :math:`N` is
            provided, then :math:`D = H = W = N`.
        padding_mode: See :py:class:`~torchio.transforms.Pad`.
        padding_fill: Same as :attr:`fill` in
            :py:class:`~torchio.transforms.Pad`.
        mode: Whether to crop/pad using the image center or the center of the
            bounding box with non-zero values of a given mask with name
            :py:attr:`mask_name`.
            Possible values are ``'center'`` or ``'mask'``.
        mask_name: If :py:attr:`mode` is ``'mask'``, name of the mask from
            which to extract the bounding box.

    Raises:
        ValueError: If :py:attr:`mode` is not ``'center'`` or ``'mask'``, or
            if :py:attr:`mode` is ``'mask'`` and :py:attr:`mask_name` is
            ``None``.

    Example:
        >>> import torchio
        >>> from torchio.transforms import CropOrPad
        >>> subject = torchio.Subject(
        ...     torchio.Image('chest_ct', 'subject_a_ct.nii.gz', torchio.INTENSITY),
        ...     torchio.Image('heart_mask', 'subject_a_heart_seg.nii.gz', torchio.LABEL),
        ... )
        >>> sample = torchio.ImagesDataset([subject])[0]
        >>> sample['chest_ct'][torchio.DATA].shape
        torch.Size([1, 512, 512, 289])
        >>> transform = CropOrPad(
        ...     (120, 80, 180),
        ...     padding_mode='reflect',
        ...     mode='mask',
        ...     mask_name='heart_mask',
        ... )
        >>> transformed = transform(sample)
        >>> transformed['chest_ct'][torchio.DATA].shape
        torch.Size([1, 120, 80, 180])
    """
    def __init__(
            self,
            target_shape: Union[int, Tuple[int, int, int]],
            padding_mode: str = 'constant',
            padding_fill: Optional[float] = None,
            mode: str = 'center',
            mask_name: Optional[str] = None,
            ):
        super().__init__(target_shape)
        # Validate arguments before storing any state so an invalid
        # configuration never produces a partially-initialized instance
        if mode not in {'center', 'mask'}:
            message = f'Mode must be "center" or "mask", not "{mode}"'
            raise ValueError(message)
        if mode == 'mask' and mask_name is None:
            message = 'If mode is "mask", mask_name cannot be None'
            raise ValueError(message)
        self.mode = mode
        self.padding_mode = padding_mode
        self.padding_fill = padding_fill
        if mode == 'mask':
            self.mask_name = mask_name
            self.compute_crop_or_pad = self._compute_mask_center_crop_or_pad
        else:
            # Define the attribute in both modes so instances are uniform
            self.mask_name = None
            self.compute_crop_or_pad = self._compute_center_crop_or_pad

    @staticmethod
    def _bbox_mask(mask_volume: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return 6 coordinates of a 3D bounding box from a given mask.

        Taken from `this SO question <https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array>`_.

        Args:
            mask_volume: 3D NumPy array.

        Returns:
            Two arrays with the minimum and maximum (inclusive) indices of
            the bounding box of the non-zero values, one entry per axis.
        """
        # Collapse each pair of axes to find which planes contain any voxel
        r = np.any(mask_volume, axis=(1, 2))
        c = np.any(mask_volume, axis=(0, 2))
        z = np.any(mask_volume, axis=(0, 1))
        # First and last True indices give the bounding box extent per axis
        rmin, rmax = np.where(r)[0][[0, -1]]
        cmin, cmax = np.where(c)[0][[0, -1]]
        zmin, zmax = np.where(z)[0][[0, -1]]
        return np.array([rmin, cmin, zmin]), np.array([rmax, cmax, zmax])

    @staticmethod
    def _get_sample_shape(sample: dict) -> Tuple[int, ...]:
        """Return the spatial shape of the first image in the sample.

        Args:
            sample: Sample dictionary mapping names to image dictionaries.

        Raises:
            ValueError: If the sample does not contain any image.
        """
        check_consistent_shape(sample)
        for image_dict in sample.values():
            if not is_image_dict(image_dict):
                continue
            # Remove channels dimension to keep only the spatial shape
            return image_dict[DATA].shape[1:]
        raise ValueError('No images found in sample')

    @staticmethod
    def _get_six_bounds_parameters(
            parameters: np.ndarray) -> Tuple[int, ...]:
        r"""Compute bounds parameters for ITK filters.

        Args:
            parameters: Tuple :math:`(d, h, w)` with the number of voxels to be
                cropped or padded.

        Returns:
            Tuple :math:`(d_{ini}, d_{fin}, h_{ini}, h_{fin}, w_{ini}, w_{fin})`,
            where :math:`n_{ini} = \left \lceil \frac{n}{2} \right \rceil` and
            :math:`n_{fin} = \left \lfloor \frac{n}{2} \right \rfloor`.

        Example:
            >>> p = np.array((4, 0, 7))
            >>> _get_six_bounds_parameters(p)
            (2, 2, 0, 0, 4, 3)

        """
        parameters = parameters / 2
        result = []
        for n in parameters:
            # Odd totals put the extra voxel at the beginning (ceil first)
            ini, fin = int(np.ceil(n)), int(np.floor(n))
            result.extend([ini, fin])
        return tuple(result)

    def _compute_center_crop_or_pad(self, sample: dict):
        """Compute padding and cropping parameters around the image center.

        Returns:
            Tuple ``(padding_params, cropping_params)``; each is either a
            6-tuple of bounds or ``None`` when no change is needed along
            any dimension.
        """
        source_shape = self._get_sample_shape(sample)
        # The parent class turns the 3-element shape tuple (d, h, w)
        # into a 6-element bounds tuple (d, d, h, h, w, w)
        target_shape = np.array(self.bounds_parameters[::2])
        diff_shape = target_shape - source_shape

        # Negative differences mean the source is larger: crop those axes
        cropping = -np.minimum(diff_shape, 0)
        if cropping.any():
            cropping_params = self._get_six_bounds_parameters(cropping)
        else:
            cropping_params = None

        # Positive differences mean the source is smaller: pad those axes
        padding = np.maximum(diff_shape, 0)
        if padding.any():
            padding_params = self._get_six_bounds_parameters(padding)
        else:
            padding_params = None

        return padding_params, cropping_params

    def _compute_mask_center_crop_or_pad(self, sample: dict):
        """Compute padding and cropping centered on a mask's bounding box.

        Args:
            sample: Sample dictionary that must contain
                :py:attr:`self.mask_name`.

        Raises:
            KeyError: If :py:attr:`self.mask_name` is not in the sample.

        Returns:
            Tuple ``(padding_params, cropping_params)`` of 6-element lists.
        """
        if self.mask_name not in sample:
            message = (
                f'Mask name "{self.mask_name}"'
                f' not found in sample keys: {tuple(sample.keys())}'
            )
            raise KeyError(message)
        mask = sample[self.mask_name][DATA].numpy()
        # Original sample shape (from mask shape)
        sample_shape = np.squeeze(mask).shape
        # Calculate bounding box of the mask center
        bb_min, bb_max = self._bbox_mask(mask[0])
        # Coordinates of the mask center
        center_mask = (bb_max - bb_min) / 2 + bb_min
        # List of padding to do
        padding = []
        # Final cropping (after padding)
        cropping = []
        for dim, center_dim in enumerate(center_mask):
            # Compute coordinates of the target shape taken from the center of
            # the mask
            begin = center_dim - (self.bounds_parameters[2 * dim] / 2)
            end = center_dim + (self.bounds_parameters[2 * dim + 1] / 2)
            # Check if dimension needs padding (before or after)
            begin_pad = round(abs(min(begin, 0)))
            end_pad = round(max(end - sample_shape[dim], 0))
            # Check if cropping is needed
            begin_crop = round(max(begin, 0))
            end_crop = abs(round(min(end - sample_shape[dim], 0)))
            # Add padding values of the dim to the list
            padding.append(begin_pad)
            padding.append(end_pad)
            # Add the slice of the dimension to take
            cropping.append(begin_crop)
            cropping.append(end_crop)
        # Conversion for SimpleITK compatibility
        padding_params = np.asarray(padding, dtype=np.uint).tolist()
        cropping_params = np.asarray(cropping, dtype=np.uint).tolist()
        return padding_params, cropping_params

    def apply_transform(self, sample: dict) -> dict:
        """Pad then crop the sample so each image has the target shape."""
        padding_params, cropping_params = self.compute_crop_or_pad(sample)
        padding_kwargs = dict(
            padding_mode=self.padding_mode, fill=self.padding_fill)
        # Pad first so that subsequent cropping indices stay valid
        if padding_params is not None:
            sample = Pad(padding_params, **padding_kwargs)(sample)
        if cropping_params is not None:
            sample = Crop(cropping_params)(sample)
        return sample
|
198
|
|
|
|
|
199
|
|
|
|
|
200
|
|
|
@deprecated('CenterCropOrPad is deprecated. Use CropOrPad instead.')
class CenterCropOrPad(CropOrPad):
    """Crop and/or pad an image to a target shape.

    Deprecated: use :py:class:`CropOrPad` with ``mode='center'`` (the
    default) instead.

    Args:
        target_shape: Tuple :math:`(D, H, W)`. If a single value :math:`N` is
            provided, then :math:`D = H = W = N`.
        padding_mode: See :py:class:`~torchio.transforms.Pad`.
        padding_fill: Same as :attr:`fill` in
            :py:class:`~torchio.transforms.Pad`.
    """

    def __init__(
            self,
            target_shape: Union[int, Tuple[int, int, int]],
            padding_mode: str = 'constant',
            padding_fill: Optional[float] = None,
            ):
        # Delegate everything to CropOrPad, forcing center-based behavior
        super().__init__(
            target_shape,
            padding_mode=padding_mode,
            padding_fill=padding_fill,
            mode='center',
        )
|
223
|
|
|
|