Passed
Push — master (7b848f...cac223) by Fernando
created 02:39

torchio.transforms.preprocessing.spatial.crop_or_pad (rating: A)

Complexity
    Total Complexity    23

Size/Duplication
    Total Lines         243
    Duplicated Lines    0%

Importance
    Changes             0

Metric    Value
eloc      133
dl        0
loc       243
rs        10
c         0
b         0
f         0
wmc       23

8 Methods

Rating  Name                                                 Duplication  Size  Complexity
A       CropOrPad._bbox_mask()                               0            20    1
B       CropOrPad.__init__()                                 0            28    5
A       CropOrPad._get_sample_shape()                        0            8     2
B       CropOrPad._compute_mask_center_crop_or_pad()         0            57    6
A       CropOrPad._compute_center_crop_or_pad()              0            12    1
A       CropOrPad._compute_cropping_padding_from_shapes()    0            20    3
A       CropOrPad.apply_transform()                          0            9     3
A       CropOrPad._get_six_bounds_parameters()               0            27    2
import warnings
from typing import Union, Tuple, Optional
import numpy as np
from deprecated import deprecated
from .pad import Pad
from .crop import Crop
from .bounds_transform import BoundsTransform, TypeShape, TypeSixBounds
from ....torchio import DATA
from ....data.subject import Subject
from ....utils import round_up


class CropOrPad(BoundsTransform):
    """Crop and/or pad an image to a target shape.

    This transform modifies the affine matrix associated with the volume so
    that the physical positions of the voxels are maintained.

    Args:
        target_shape: Tuple :math:`(D, H, W)`. If a single value :math:`N` is
            provided, then :math:`D = H = W = N`.
        padding_mode: See :py:class:`~torchio.transforms.Pad`.
        padding_fill: Same as :attr:`fill` in
            :py:class:`~torchio.transforms.Pad`.
        mask_name: If ``None``, the centers of the input and output volumes
            will be the same.
            If a string is given, the output volume center will be the center
            of the bounding box of non-zero values in the image named
            :py:attr:`mask_name`.
        p: Probability that this transform will be applied.

    Example:
        >>> import torchio
        >>> from torchio.transforms import CropOrPad
        >>> subject = torchio.Subject(
        ...     torchio.Image('chest_ct', 'subject_a_ct.nii.gz', torchio.INTENSITY),
        ...     torchio.Image('heart_mask', 'subject_a_heart_seg.nii.gz', torchio.LABEL),
        ... )
        >>> sample = torchio.ImagesDataset([subject])[0]
        >>> sample['chest_ct'][torchio.DATA].shape
        torch.Size([1, 512, 512, 289])
        >>> transform = CropOrPad(
        ...     (120, 80, 180),
        ...     mask_name='heart_mask',
        ... )
        >>> transformed = transform(sample)
        >>> transformed['chest_ct'][torchio.DATA].shape
        torch.Size([1, 120, 80, 180])
    """
    def __init__(
            self,
            target_shape: Union[int, TypeShape],
            padding_mode: str = 'constant',
            padding_fill: Optional[float] = None,
            mask_name: Optional[str] = None,
            p: float = 1,
            ):
        super().__init__(target_shape, p=p)
        self.padding_mode = padding_mode
        self.padding_fill = padding_fill
        if mask_name is not None and not isinstance(mask_name, str):
            message = (
                'If mask_name is not None, it must be a string,'
                f' not {type(mask_name)}'
            )
            raise ValueError(message)
        self.mask_name = mask_name
        if self.mask_name is None:
            self.compute_crop_or_pad = self._compute_center_crop_or_pad
        else:
            if not isinstance(mask_name, str):
                message = (
                    'If mask_name is not None, it must be a string,'
                    f' not {type(mask_name)}'
                )
                raise ValueError(message)
            self.compute_crop_or_pad = self._compute_mask_center_crop_or_pad

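    # For example (reusing the image names from the class docstring, as an
    # illustration only), CropOrPad((128, 128, 128)) uses the volume-center
    # strategy, whereas CropOrPad((128, 128, 128), mask_name='heart_mask')
    # centers the output on the bounding box of the non-zero values in the
    # 'heart_mask' image.
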
    @staticmethod
80
    def _bbox_mask(
81
            mask_volume: np.ndarray,
82
            ) -> Tuple[np.ndarray, np.ndarray]:
83
        """Return 6 coordinates of a 3D bounding box from a given mask.
84
85
        Taken from `this SO question <https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array>`_.
86
87
        Args:
88
            mask_volume: 3D NumPy array.
89
        """
90
        i_any = np.any(mask_volume, axis=(1, 2))
91
        j_any = np.any(mask_volume, axis=(0, 2))
92
        k_any = np.any(mask_volume, axis=(0, 1))
93
        i_min, i_max = np.where(i_any)[0][[0, -1]]
94
        j_min, j_max = np.where(j_any)[0][[0, -1]]
95
        k_min, k_max = np.where(k_any)[0][[0, -1]]
96
        bb_min = np.array([i_min, j_min, k_min])
97
        bb_max = np.array([i_max, j_max, k_max])
98
        return bb_min, bb_max
99
100
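    # Worked example: for a hypothetical 3x3x3 mask with a single non-zero
    # voxel at index (1, 2, 0), every axis reduces to that single voxel, so
    # both bb_min and bb_max are array([1, 2, 0]).
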
    @staticmethod
    def _get_sample_shape(sample: Subject) -> TypeShape:
        """Return the shape of the first image in the sample."""
        sample.check_consistent_shape()
        for image_dict in sample.get_images(intensity_only=False):
            data = image_dict[DATA].shape[1:]  # remove channels dimension
            break
        return data
    # Inspection note: the variable `data` does not seem to be defined for all
    # execution paths (it would be unbound if the sample contained no images).

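    # For the sample built in the class docstring, whose images have shape
    # torch.Size([1, 512, 512, 289]), this returns torch.Size([512, 512, 289]).
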
    @staticmethod
    def _get_six_bounds_parameters(
            parameters: np.ndarray,
            ) -> TypeSixBounds:
        r"""Compute bounds parameters for ITK filters.

        Args:
            parameters: Tuple :math:`(d, h, w)` with the number of voxels to be
                cropped or padded.

        Returns:
            Tuple :math:`(d_{ini}, d_{fin}, h_{ini}, h_{fin}, w_{ini}, w_{fin})`,
            where :math:`n_{ini} = \left \lceil \frac{n}{2} \right \rceil` and
            :math:`n_{fin} = \left \lfloor \frac{n}{2} \right \rfloor`.

        Example:
            >>> p = np.array((4, 0, 7))
            >>> _get_six_bounds_parameters(p)
            (2, 2, 0, 0, 4, 3)

        """
        parameters = parameters / 2
        result = []
        for number in parameters:
            ini, fin = int(np.ceil(number)), int(np.floor(number))
            result.extend([ini, fin])
        return tuple(result)

    def _compute_cropping_padding_from_shapes(
            self,
            source_shape: TypeShape,
            target_shape: TypeShape,
            ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        diff_shape = target_shape - source_shape

        cropping = -np.minimum(diff_shape, 0)
        if cropping.any():
            cropping_params = self._get_six_bounds_parameters(cropping)
        else:
            cropping_params = None

        padding = np.maximum(diff_shape, 0)
        if padding.any():
            padding_params = self._get_six_bounds_parameters(padding)
        else:
            padding_params = None

        return padding_params, cropping_params

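    # Worked example, reusing the shapes from the class docstring: with
    # source_shape (512, 512, 289) and target_shape (120, 80, 180), diff_shape
    # is (-392, -432, -109), so padding_params is None and cropping is
    # (392, 432, 109), which _get_six_bounds_parameters turns into
    # (196, 196, 216, 216, 55, 54).
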
    def _compute_center_crop_or_pad(
            self,
            sample: Subject,
            ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        source_shape = self._get_sample_shape(sample)
        # The parent class turns the 3-element shape tuple (d, h, w)
        # into a 6-element bounds tuple (d, d, h, h, w, w)
        target_shape = np.array(self.bounds_parameters[::2])
        parameters = self._compute_cropping_padding_from_shapes(
            source_shape, target_shape)
        padding_params, cropping_params = parameters
        return padding_params, cropping_params

    def _compute_mask_center_crop_or_pad(
            self,
            sample: Subject,
            ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        if self.mask_name not in sample:
            message = (
                f'Mask name "{self.mask_name}"'
                f' not found in sample keys "{tuple(sample.keys())}".'
                ' Using volume center instead'
            )
            warnings.warn(message)
            return self._compute_center_crop_or_pad(sample=sample)

        mask = sample[self.mask_name][DATA].numpy()

        if not np.any(mask):
            message = (
                f'All values found in the mask "{self.mask_name}"'
                ' are zero. Using volume center instead'
            )
            warnings.warn(message)
            return self._compute_center_crop_or_pad(sample=sample)

        # Original sample shape (without the channels dimension)
        sample_shape = self._get_sample_shape(sample)
        # Calculate the bounding box of the mask
        bb_min, bb_max = self._bbox_mask(mask[0])
        # Coordinates of the mask center
        center_mask = (bb_max - bb_min) / 2 + bb_min
        # List of padding to do
        padding = []
        # Final cropping (after padding)
        cropping = []
        for dim, center_dimension in enumerate(center_mask):
            # Compute coordinates of the target shape taken from the center of
            # the mask
            center_dim = round_up(center_dimension)
            begin = center_dim - (self.bounds_parameters[2 * dim] / 2)
            end = center_dim + (self.bounds_parameters[2 * dim + 1] / 2)
            # Check if dimension needs padding (before or after)
            begin_pad = round_up(abs(min(begin, 0)))
            end_pad = round(max(end - sample_shape[dim], 0))
            # Check if cropping is needed
            begin_crop = round_up(max(begin, 0))
            end_crop = abs(round(min(end - sample_shape[dim], 0)))
            # Add padding values of the dim to the list
            padding.append(begin_pad)
            padding.append(end_pad)
            # Add the slice of the dimension to take
            cropping.append(begin_crop)
            cropping.append(end_crop)
        # Conversion for SimpleITK compatibility
        padding = np.asarray(padding, dtype=int)
        cropping = np.asarray(cropping, dtype=int)
        padding_params = tuple(padding.tolist()) if padding.any() else None
        cropping_params = tuple(cropping.tolist()) if cropping.any() else None
        return padding_params, cropping_params

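    # Worked example for a single axis, assuming an axis length of 512, a
    # target size of 120 (60 voxels on each side of the mask center) and a
    # mask center at index 30: begin = -30 and end = 90, so the axis is padded
    # by 30 voxels at the start (begin_pad=30, end_pad=0) and cropped by 422
    # voxels at the end (begin_crop=0, end_crop=422), giving
    # 512 + 30 - 422 = 120 voxels.
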
    def apply_transform(self, sample: Subject) -> dict:
        padding_params, cropping_params = self.compute_crop_or_pad(sample)
        padding_kwargs = dict(
            padding_mode=self.padding_mode, fill=self.padding_fill)
        if padding_params is not None:
            sample = Pad(padding_params, **padding_kwargs)(sample)
        if cropping_params is not None:
            sample = Crop(cropping_params)(sample)
        return sample


@deprecated('CenterCropOrPad is deprecated. Use CropOrPad instead.')
class CenterCropOrPad(CropOrPad):
    """Crop or pad around image center."""
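
A minimal usage sketch for the no-mask (volume-center) path, mirroring the class docstring example; the image name 't1', the file name 'subject_a_t1.nii.gz' and the target shape are illustrative assumptions:

    import torchio
    from torchio.transforms import CropOrPad

    # Build a one-image subject; any readable volume on disk would do here.
    subject = torchio.Subject(
        torchio.Image('t1', 'subject_a_t1.nii.gz', torchio.INTENSITY),
    )
    sample = torchio.ImagesDataset([subject])[0]

    # With mask_name=None, the input and output volume centers coincide, so
    # the image is symmetrically padded and/or cropped to the target shape.
    transform = CropOrPad((128, 128, 128))
    transformed = transform(sample)
    print(transformed['t1'][torchio.DATA].shape)  # torch.Size([1, 128, 128, 128])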