import warnings
from numbers import Number
from typing import Union

import numpy as np
import torch
from nibabel.affines import apply_affine

from ....data.image import Image
from ....data.subject import Subject
from .bounds_transform import BoundsTransform
from .bounds_transform import TypeBounds


class Pad(BoundsTransform):
    r"""Pad an image.

    Args:
        padding: Tuple
            :math:`(w_{ini}, w_{fin}, h_{ini}, h_{fin}, d_{ini}, d_{fin})`
            defining the number of values padded to the edges of each axis.
            If the initial shape of the image is
            :math:`W \times H \times D`, the final shape will be
            :math:`(w_{ini} + W + w_{fin}) \times (h_{ini} + H + h_{fin})
            \times (d_{ini} + D + d_{fin})`.
            If only three values :math:`(w, h, d)` are provided, then
            :math:`w_{ini} = w_{fin} = w`,
            :math:`h_{ini} = h_{fin} = h` and
            :math:`d_{ini} = d_{fin} = d`.
            If only one value :math:`n` is provided, then
            :math:`w_{ini} = w_{fin} = h_{ini} = h_{fin} =
            d_{ini} = d_{fin} = n`.
        padding_mode: See possible modes in `NumPy docs`_. If it is a number,
            the mode will be set to ``'constant'``. If it is ``'mean'``,
            ``'maximum'``, ``'median'`` or ``'minimum'``, the statistic will be
            computed from the whole volume, unlike in NumPy, which computes it
            along the padded axis.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.
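
    Example:
        A minimal usage sketch; it assumes ``torchio`` is importable and uses
        the bundled ``Colin27`` sample subject, with illustrative padding
        amounts.

        >>> import torchio as tio
        >>> subject = tio.datasets.Colin27()
        >>> transform = tio.Pad((4, 0, 0, 0, 2, 2), padding_mode='edge')
        >>> padded = transform(subject)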

    .. seealso:: If you want to pass the output shape instead, use
        :class:`~torchio.transforms.CropOrPad`.

    .. _NumPy docs: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
    """

    PADDING_MODES = (
        'empty',
        'edge',
        'wrap',
        'constant',
        'linear_ramp',
        'maximum',
        'mean',
        'median',
        'minimum',
        'reflect',
        'symmetric',
    )

    def __init__(
        self,
        padding: TypeBounds,
        padding_mode: Union[str, float] = 0,
        **kwargs,
    ):
        super().__init__(padding, **kwargs)
        self.padding = padding
        self.check_padding_mode(padding_mode)
        self.padding_mode = padding_mode
        self.args_names = ['padding', 'padding_mode']

    @classmethod
    def check_padding_mode(cls, padding_mode):
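        # A mode is valid if it is one of the NumPy padding modes above,
        # a number to pad with, or a callable (as accepted by np.pad).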
        is_number = isinstance(padding_mode, Number)
        is_callable = callable(padding_mode)
        if not (padding_mode in cls.PADDING_MODES or is_number or is_callable):
            message = (
                f'Padding mode "{padding_mode}" not valid. Valid options are'
                f' {list(cls.PADDING_MODES)}, a number or a function'
            )
            raise KeyError(message)

    def _check_truncation(self, image: Image, mode: Union[str, float]) -> None:
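        # The mean or median of an integer image is generally fractional;
        # writing it back into an integer array would silently truncate it,
        # so warn the user in that case.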
        if mode not in ('mean', 'median'):
            return
        if torch.is_floating_point(image.data):
            return
        message = (
            f'The constant value computed for padding mode "{mode}" might'
            ' be truncated in the output, as the input image is not'
            ' floating point. Consider converting the image to a floating'
            ' point type before applying this transform.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)

    def apply_transform(self, subject: Subject) -> Subject:
        assert self.bounds_parameters is not None
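        # Even-indexed entries are the amounts prepended to each spatial
        # axis: (w_ini, h_ini, d_ini).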
        low = self.bounds_parameters[::2]
        for image in self.get_images(subject):
            self._check_truncation(image, self.padding_mode)
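            # Prepending voxels moves index (0, 0, 0); map the shift through
            # the affine so the image keeps its position in world coordinates.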
            new_origin = apply_affine(image.affine, -np.array(low))
            new_affine = image.affine.copy()
            new_affine[:3, 3] = new_origin
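
            # A numeric mode pads with that value; the statistical modes are
            # reduced over the whole volume here (NumPy would compute them
            # per padded axis); any other mode is passed through to np.pad.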
            mode: str | float = 'constant'
            constant: torch.Tensor | float | None = None
            kwargs: dict[str, str | float | torch.Tensor] = {}
            if isinstance(self.padding_mode, Number):
                constant = self.padding_mode  # type: ignore[assignment]
            elif self.padding_mode == 'maximum':
                constant = image.data.max()
            elif self.padding_mode == 'mean':
                constant = image.data.float().mean()
            elif self.padding_mode == 'median':
                constant = torch.quantile(image.data.float(), 0.5)
            elif self.padding_mode == 'minimum':
                constant = image.data.min()
            else:
                constant = None
                mode = self.padding_mode

            if constant is not None:
                kwargs['constant_values'] = constant
            kwargs['mode'] = mode
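
            # np.pad expects one (before, after) pair per axis; the first
            # axis is the channel dimension, which is never padded.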
            pad_params = self.bounds_parameters
            paddings = (0, 0), pad_params[:2], pad_params[2:4], pad_params[4:]
            padded = np.pad(image.data, paddings, **kwargs)  # type: ignore[call-overload]
            image.set_data(torch.as_tensor(padded))
            image.affine = new_affine
        return subject

    def inverse(self):
        from .crop import Crop
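
        # Cropping by the same bounds undoes the padding; the local import
        # avoids a circular dependency between the pad and crop modules.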
        return Crop(self.padding)