Passed
Push — main ( 6ea3bf...edcd90 )
by Fernando
01:38
created

torchio.transforms.preprocessing.spatial.pad   A

Complexity

Total Complexity 17

Size/Duplication

Total Lines 138
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 85
dl 0
loc 138
rs 10
c 0
b 0
f 0
wmc 17

5 Methods

Rating   Name   Duplication   Size   Complexity  
A Pad.__init__() 0 11 1
A Pad.check_padding_mode() 0 10 4
A Pad.inverse() 0 4 1
B Pad.apply_transform() 0 36 8
A Pad._check_truncation() 0 12 3
1
import warnings
2
from numbers import Number
3
from typing import Union
4
5
import numpy as np
6
import torch
7
from nibabel.affines import apply_affine
8
9
from ....data.image import Image
10
from ....data.subject import Subject
11
from .bounds_transform import BoundsTransform
12
from .bounds_transform import TypeBounds
13
14
15
class Pad(BoundsTransform):
    r"""Pad an image.

    Args:
        padding: Tuple
            :math:`(w_{ini}, w_{fin}, h_{ini}, h_{fin}, d_{ini}, d_{fin})`
            defining the number of values padded to the edges of each axis.
            If the initial shape of the image is
            :math:`W \times H \times D`, the final shape will be
            :math:`(w_{ini} + W + w_{fin}) \times (h_{ini} + H + h_{fin})
            \times (d_{ini} + D + d_{fin})`.
            If only three values :math:`(w, h, d)` are provided, then
            :math:`w_{ini} = w_{fin} = w`,
            :math:`h_{ini} = h_{fin} = h` and
            :math:`d_{ini} = d_{fin} = d`.
            If only one value :math:`n` is provided, then
            :math:`w_{ini} = w_{fin} = h_{ini} = h_{fin} =
            d_{ini} = d_{fin} = n`.
        padding_mode: See possible modes in `NumPy docs`_. If it is a number,
            the mode will be set to ``'constant'``. If it is ``'mean'``,
            ``'maximum'``, ``'median'`` or ``'minimum'``, the statistic will be
            computed from the whole volume, unlike in NumPy, which computes it
            along the padded axis. A callable is also accepted; it is passed
            unchanged to :func:`numpy.pad` as its ``mode``.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    .. seealso:: If you want to pass the output shape instead, please use
        :class:`~torchio.transforms.CropOrPad` instead.

    .. _NumPy docs: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
    """

    PADDING_MODES = (
        'empty',
        'edge',
        'wrap',
        'constant',
        'linear_ramp',
        'maximum',
        'mean',
        'median',
        'minimum',
        'reflect',
        'symmetric',
    )

    def __init__(
        self,
        padding: TypeBounds,
        padding_mode: Union[str, float] = 0,
        **kwargs,
    ):
        super().__init__(padding, **kwargs)
        self.padding = padding
        self.check_padding_mode(padding_mode)
        self.padding_mode = padding_mode
        self.args_names = ['padding', 'padding_mode']

    @classmethod
    def check_padding_mode(cls, padding_mode):
        """Validate ``padding_mode``.

        Accepted values are the names in ``PADDING_MODES``, any number
        (used as a constant fill value) or any callable.

        Raises:
            KeyError: If ``padding_mode`` is none of the above.
        """
        is_number = isinstance(padding_mode, Number)
        is_callable = callable(padding_mode)
        if not (padding_mode in cls.PADDING_MODES or is_number or is_callable):
            message = (
                f'Padding mode "{padding_mode}" not valid. Valid options are'
                f' {list(cls.PADDING_MODES)}, a number or a function'
            )
            raise KeyError(message)

    def _check_truncation(self, image: Image, mode: Union[str, float]) -> None:
        """Warn if a ``'mean'``/``'median'`` fill value may lose precision.

        Those statistics can be non-integer, so writing them into an image
        whose data is not floating point may truncate them. Other modes
        produce values already representable in the image dtype.
        """
        if mode not in ('mean', 'median'):
            return
        if torch.is_floating_point(image.data):
            return
        message = (
            f'The constant value computed for padding mode "{mode}" might'
            ' be truncated in the output, as the input image is not'
            ' floating point. Consider converting the image to a floating'
            ' point type before applying this transform.'
        )
        warnings.warn(message, RuntimeWarning, stacklevel=2)

    @staticmethod
    def _median(data: torch.Tensor):
        """Return the median of all values in ``data`` as a float32 scalar."""
        # torch.quantile rejects inputs with more than 2**24 elements, which
        # many 3D volumes exceed, so NumPy is used instead. For even-sized
        # inputs np.median averages the two middle values, matching
        # torch.quantile(..., 0.5) with its default linear interpolation.
        return np.median(data.float().numpy())

    def _get_pad_kwargs(self, image: Image) -> dict:
        """Build the ``mode`` (and ``constant_values``) kwargs for np.pad.

        Numbers and the whole-volume statistics ``'maximum'``, ``'mean'``,
        ``'median'`` and ``'minimum'`` become a ``'constant'`` pad with that
        value. Any other mode (a NumPy mode name or a callable) is forwarded
        to :func:`numpy.pad` unchanged.
        """
        padding_mode = self.padding_mode
        data = image.data
        if isinstance(padding_mode, Number):
            constant = padding_mode
        elif padding_mode == 'maximum':
            constant = data.max()
        elif padding_mode == 'mean':
            constant = data.float().mean()
        elif padding_mode == 'median':
            constant = self._median(data)
        elif padding_mode == 'minimum':
            constant = data.min()
        else:
            return {'mode': padding_mode}
        return {'mode': 'constant', 'constant_values': constant}

    def apply_transform(self, subject: Subject) -> Subject:
        """Pad every selected image in ``subject`` and update its affine.

        Args:
            subject: Subject whose images (as returned by ``get_images``)
                are padded in place.

        Returns:
            The same ``subject`` instance.
        """
        assert self.bounds_parameters is not None
        pad_params = self.bounds_parameters
        # Even entries are the "initial" (low-index) padding of each axis.
        low = pad_params[::2]
        # The leading (0, 0) leaves the first (channel) axis unpadded.
        paddings = (0, 0), pad_params[:2], pad_params[2:4], pad_params[4:]
        for image in self.get_images(subject):
            self._check_truncation(image, self.padding_mode)
            # The old voxel at index 0 moves to index ``low``, so the new
            # origin is the world position of old voxel index ``-low``.
            new_affine = image.affine.copy()
            new_affine[:3, 3] = apply_affine(image.affine, -np.array(low))

            pad_kwargs = self._get_pad_kwargs(image)
            padded = np.pad(image.data, paddings, **pad_kwargs)  # type: ignore[call-overload]
            image.set_data(torch.as_tensor(padded))
            image.affine = new_affine
        return subject

    def inverse(self):
        """Return a :class:`Crop` that removes the padding added here."""
        from .crop import Crop

        return Crop(self.padding)
138