Passed
Pull Request — master (#110)
by Fernando
01:29
created

torchio.transforms.preprocessing.spatial.crop_or_pad   A

Complexity

Total Complexity 20

Size/Duplication

Total Lines 222
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 117
dl 0
loc 222
rs 10
c 0
b 0
f 0
wmc 20

8 Methods

Rating  Name                                          Duplication  Size  Complexity
A       CropOrPad._get_sample_shape()                 0            10    3
A       CropOrPad._bbox_mask()                        0            16    1
A       CropOrPad._compute_mask_center_crop_or_pad()  0            39    3
A       CenterCropOrPad.__init__()                    0            11    1
A       CropOrPad.__init__()                          0            23    4
A       CropOrPad._compute_center_crop_or_pad()       0            20    3
A       CropOrPad.apply_transform()                   0            9     3
A       CropOrPad._get_six_bounds_parameters()        0            25    2
from typing import Union, Tuple, Optional
introduced by
Missing module docstring
import numpy as np
introduced by
Unable to import 'numpy'
from deprecated import deprecated
introduced by
Unable to import 'deprecated'
from .pad import Pad
from .crop import Crop
from .bounds_transform import BoundsTransform
from ....torchio import DATA
from ....utils import is_image_dict, check_consistent_shape


class CropOrPad(BoundsTransform):
Bug introduced by
The method bounds_function, which was declared abstract in the superclass BoundsTransform,
was not overridden.

Methods which raise NotImplementedError should be overridden in concrete child classes.
    """Crop and/or pad an image to a target shape.

    This transform modifies the affine matrix associated with the volume so that
    the physical positions of the voxels are maintained.

    Args:
        target_shape: Tuple :math:`(D, H, W)`. If a single value :math:`N` is
            provided, then :math:`D = H = W = N`.
        padding_mode: See :py:class:`~torchio.transforms.Pad`.
        padding_fill: Same as :attr:`fill` in
            :py:class:`~torchio.transforms.Pad`.
        mode: Whether to crop/pad using the image center or the center of the
            bounding box of the non-zero values of the mask named
            :py:attr:`mask_name`.
            Possible values are ``'center'`` or ``'mask'``.
        mask_name: If :py:attr:`mode` is ``'mask'``, name of the mask from
            which to extract the bounding box.

    Example:
        >>> import torchio
        >>> from torchio.transforms import CropOrPad
        >>> subject = torchio.Subject(
        ...     torchio.Image('chest_ct', 'subject_a_ct.nii.gz', torchio.INTENSITY),
        ...     torchio.Image('heart_mask', 'subject_a_heart_seg.nii.gz', torchio.LABEL),
        ... )
        >>> sample = torchio.ImagesDataset([subject])[0]
        >>> sample['chest_ct'][torchio.DATA].shape
        torch.Size([1, 512, 512, 289])
        >>> transform = CropOrPad(
        ...     (120, 80, 180),
        ...     padding_mode='reflect',
        ...     mode='mask',
        ...     mask_name='heart_mask',
        ... )
        >>> transformed = transform(sample)
        >>> transformed['chest_ct'][torchio.DATA].shape
        torch.Size([1, 120, 80, 180])
    """
    def __init__(
best-practice introduced by
Too many arguments (6/5)
            self,
            target_shape: Union[int, Tuple[int, int, int]],
            padding_mode: str = 'constant',
            padding_fill: Optional[float] = None,
            mode: str = 'center',
            mask_name: Optional[str] = None,
            ):
Coding Style introduced by
Wrong hanging indentation before block.
        super().__init__(target_shape)
        self.mode = mode
        self.padding_mode = padding_mode
        self.padding_fill = padding_fill
        if mode not in {'center', 'mask'}:
            message = f'Mode must be "center" or "mask", not "{mode}"'
            raise ValueError(message)
        if mode == 'mask':
            if mask_name is None:
                message = 'If mode is "mask", mask_name cannot be None'
                raise ValueError(message)
            self.mask_name = mask_name
            self.compute_crop_or_pad = self._compute_mask_center_crop_or_pad
        else:
            self.compute_crop_or_pad = self._compute_center_crop_or_pad

    @staticmethod
    def _bbox_mask(mask_volume: np.ndarray):
        """Return the 3D bounding box of the non-zero values of a mask.

        The 6 coordinates are returned as two arrays holding the minimum and
        maximum index along each axis.

        Taken from `this SO question <https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array>`_.
Coding Style introduced by
This line is too long as per the coding-style (114/100).

This check looks for lines that are too long. You can specify the maximum line length.

        Args:
            mask_volume: 3D NumPy array.
        """
        r = np.any(mask_volume, axis=(1, 2))
Coding Style Naming introduced by
Variable name "r" doesn't conform to snake_case naming style ('(([a-z_][a-z0-9_]{2,})|(_[a-z0-9_]*)|(__[a-z][a-z0-9_]+__))$' pattern)

This check looks for invalid names for a range of different identifiers.

You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.

If your project includes a Pylint configuration file, the settings contained in that file take precedence.

To find out more about Pylint, please refer to their site.
        c = np.any(mask_volume, axis=(0, 2))
Coding Style Naming introduced by
Variable name "c" doesn't conform to snake_case naming style ('(([a-z_][a-z0-9_]{2,})|(_[a-z0-9_]*)|(__[a-z][a-z0-9_]+__))$' pattern)
        z = np.any(mask_volume, axis=(0, 1))
Coding Style Naming introduced by
Variable name "z" doesn't conform to snake_case naming style ('(([a-z_][a-z0-9_]{2,})|(_[a-z0-9_]*)|(__[a-z][a-z0-9_]+__))$' pattern)
        rmin, rmax = np.where(r)[0][[0, -1]]
        cmin, cmax = np.where(c)[0][[0, -1]]
        zmin, zmax = np.where(z)[0][[0, -1]]
        return np.array([rmin, cmin, zmin]), np.array([rmax, cmax, zmax])

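The `np.any`/`np.where` technique used by `_bbox_mask` can be checked in isolation. The sketch below reimplements it as a standalone function; the name `bbox_3d` is hypothetical. As an added assumption not present in the original, it raises `ValueError` for an empty mask. The original instead fails with an `IndexError` when it indexes the empty result of `np.where`.

```python
import numpy as np


def bbox_3d(mask: np.ndarray):
    """Return (min_corner, max_corner) of the non-zero voxels of a 3D mask.

    Both corners are inclusive voxel indices. Hypothetical helper that
    mirrors CropOrPad._bbox_mask.
    """
    if mask.ndim != 3:
        raise ValueError(f'Expected a 3D array, got {mask.ndim} dimensions')
    if not mask.any():
        # Assumption: fail with a clear message instead of an IndexError
        raise ValueError('Mask is empty; bounding box is undefined')
    mins, maxs = [], []
    for axis in range(3):
        # Collapse the other two axes: True where this slice has any voxel set
        other_axes = tuple(a for a in range(3) if a != axis)
        indices = np.where(np.any(mask, axis=other_axes))[0]
        mins.append(indices[0])
        maxs.append(indices[-1])
    return np.array(mins), np.array(maxs)


mask = np.zeros((10, 8, 6), dtype=bool)
mask[2:5, 3, 1:4] = True
bb_min, bb_max = bbox_3d(mask)
print(bb_min, bb_max)  # [2 3 1] [4 3 3]
```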
    @staticmethod
    def _get_sample_shape(sample: dict) -> Tuple[int]:
        """Return the spatial shape of the first image in the sample."""
        check_consistent_shape(sample)
        for image_dict in sample.values():
            if not is_image_dict(image_dict):
                continue
            data = image_dict[DATA].shape[1:]  # remove channels dimension
            break
        return data
introduced by
The variable data does not seem to be defined for all execution paths.

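One way to resolve the "variable data does not seem to be defined for all execution paths" report is to return from inside the loop and raise explicitly when no image is found. The sketch below is a standalone version under stated assumptions. `DATA`, `is_image_dict` and `FakeTensor` are hypothetical stand-ins for torchio's key, helper and tensor type, which it does not import, and the consistency check is omitted. It also uses `Tuple[int, ...]`, the variable-length form. The original `Tuple[int]` annotation describes a one-element tuple.

```python
from typing import Any, Dict, Tuple

DATA = 'data'  # hypothetical stand-in for torchio.DATA


class FakeTensor:
    """Minimal object exposing a .shape attribute, like a torch.Tensor."""

    def __init__(self, *shape: int):
        self.shape = shape


def is_image_dict(value: Any) -> bool:
    # Hypothetical stand-in for torchio.utils.is_image_dict
    return isinstance(value, dict) and DATA in value


def get_sample_shape(sample: Dict[str, Any]) -> Tuple[int, ...]:
    """Return the spatial shape of the first image in the sample."""
    for image_dict in sample.values():
        if is_image_dict(image_dict):
            # Drop the leading channels dimension
            return tuple(image_dict[DATA].shape[1:])
    # Only reached when no value is an image, so nothing is left unbound
    raise ValueError('No images found in sample')


sample = {'id': 'subject_a', 't1': {DATA: FakeTensor(1, 64, 64, 32)}}
print(get_sample_shape(sample))  # (64, 64, 32)
```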
    @staticmethod
    def _get_six_bounds_parameters(parameters: np.ndarray):
        r"""Compute bounds parameters for ITK filters.

        Args:
            parameters: Array :math:`(d, h, w)` with the number of voxels to be
                cropped or padded.

        Returns:
            Tuple :math:`(d_{ini}, d_{fin}, h_{ini}, h_{fin}, w_{ini}, w_{fin})`,
            where :math:`n_{ini} = \left \lceil \frac{n}{2} \right \rceil` and
            :math:`n_{fin} = \left \lfloor \frac{n}{2} \right \rfloor`.

        Example:
            >>> p = np.array((4, 0, 7))
            >>> CropOrPad._get_six_bounds_parameters(p)
            (2, 2, 0, 0, 4, 3)

        """
        parameters = parameters / 2
        result = []
        for n in parameters:
Coding Style Naming introduced by
Variable name "n" doesn't conform to snake_case naming style ('(([a-z_][a-z0-9_]{2,})|(_[a-z0-9_]*)|(__[a-z][a-z0-9_]+__))$' pattern)
            ini, fin = int(np.ceil(n)), int(np.floor(n))
            result.extend([ini, fin])
        return tuple(result)

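The ceil/floor split documented for `_get_six_bounds_parameters` can also be computed with integer arithmetic alone, which avoids the round trip through floating point. The minimal sketch below uses `n - n // 2` for the ceiling and `n // 2` for the floor, which holds for non-negative integers. The function name `six_bounds` is hypothetical.

```python
from typing import Iterable, Tuple


def six_bounds(parameters: Iterable[int]) -> Tuple[int, ...]:
    """Split each voxel count n into (ceil(n / 2), floor(n / 2)).

    Integer-only equivalent of CropOrPad._get_six_bounds_parameters,
    assuming every n is a non-negative integer.
    """
    result = []
    for n in parameters:
        n = int(n)
        if n < 0:
            raise ValueError(f'Expected a non-negative count, got {n}')
        floor_half = n // 2
        # For n >= 0, n - floor(n / 2) == ceil(n / 2)
        result.extend([n - floor_half, floor_half])
    return tuple(result)


print(six_bounds((4, 0, 7)))  # (2, 2, 0, 0, 4, 3)
```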
    def _compute_center_crop_or_pad(self, sample: dict):
        source_shape = self._get_sample_shape(sample)
        # The parent class turns the 3-element shape tuple (d, h, w)
        # into a 6-element bounds tuple (d, d, h, h, w, w)
        target_shape = np.array(self.bounds_parameters[::2])
        diff_shape = target_shape - source_shape

        cropping = -np.minimum(diff_shape, 0)
        if cropping.any():
            cropping_params = self._get_six_bounds_parameters(cropping)
        else:
            cropping_params = None

        padding = np.maximum(diff_shape, 0)
        if padding.any():
            padding_params = self._get_six_bounds_parameters(padding)
        else:
            padding_params = None

        return padding_params, cropping_params

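In center mode, the work reduces to a per-axis difference between the target and source shapes. Negative differences are cropped and positive ones padded, and each amount is split between the two ends with the extra voxel on the initial side. The sketch below shows that arithmetic, assuming shapes are plain 3-tuples. The function names are hypothetical. As in `_compute_center_crop_or_pad`, `None` is returned when no cropping or padding is needed.

```python
from typing import Optional, Sequence, Tuple

Bounds = Optional[Tuple[int, ...]]


def split_halves(counts: Sequence[int]) -> Tuple[int, ...]:
    # (ceil(n / 2), floor(n / 2)) for each non-negative count n
    result = []
    for n in counts:
        result.extend([n - n // 2, n // 2])
    return tuple(result)


def center_crop_or_pad(
    source_shape: Sequence[int],
    target_shape: Sequence[int],
) -> Tuple[Bounds, Bounds]:
    """Return (padding_params, cropping_params) for center mode."""
    diff = [t - s for s, t in zip(source_shape, target_shape)]
    cropping = [max(-d, 0) for d in diff]  # voxels to remove per axis
    padding = [max(d, 0) for d in diff]    # voxels to add per axis
    cropping_params = split_halves(cropping) if any(cropping) else None
    padding_params = split_halves(padding) if any(padding) else None
    return padding_params, cropping_params


print(center_crop_or_pad((512, 512, 289), (120, 80, 300)))
# ((0, 0, 0, 0, 6, 5), (196, 196, 216, 216, 0, 0))
```

A single axis can need cropping while another needs padding, which is why both tuples are computed independently rather than choosing one operation per sample.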
    def _compute_mask_center_crop_or_pad(self, sample: dict):
Comprehensibility introduced by
This function exceeds the maximum number of variables (20/15).
        if self.mask_name not in sample:
            message = (
                f'Mask name "{self.mask_name}"'
                f' not found in sample keys: {tuple(sample.keys())}'
            )
            raise KeyError(message)
        mask = sample[self.mask_name][DATA].numpy()
        # Original sample shape (from mask shape, without the channels dimension)
        sample_shape = mask.shape[1:]
        # Bounding box of the non-zero values of the mask
        bb_min, bb_max = self._bbox_mask(mask[0])
        # Coordinates of the mask center
        center_mask = (bb_max - bb_min) / 2 + bb_min
        # Padding before and after each dimension
        padding = []
        # Cropping before and after each dimension (applied after padding)
        cropping = []
        for dim, center_dim in enumerate(center_mask):
            # Compute coordinates of the target shape taken from the center of
            # the mask
            begin = center_dim - (self.bounds_parameters[2 * dim] / 2)
            end = center_dim + (self.bounds_parameters[2 * dim + 1] / 2)
            # Check if dimension needs padding (before or after)
            begin_pad = round(abs(min(begin, 0)))
            end_pad = round(max(end - sample_shape[dim], 0))
            # Check if cropping is needed
            begin_crop = round(max(begin, 0))
            end_crop = abs(round(min(end - sample_shape[dim], 0)))
            # Add padding values of the dim to the list
            padding.append(begin_pad)
            padding.append(end_pad)
            # Add cropping values of the dim to the list
            cropping.append(begin_crop)
            cropping.append(end_crop)
        # Conversion for SimpleITK compatibility
        padding_params = np.asarray(padding, dtype=np.uint).tolist()
        cropping_params = np.asarray(cropping, dtype=np.uint).tolist()
        return padding_params, cropping_params

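Mask mode places a window of the target size around the center of the mask's bounding box. Any part of the window that falls outside the volume becomes padding, and any part of the volume outside the window becomes cropping. The sketch below reproduces that per-axis arithmetic for plain tuples. The function name is hypothetical. Like the original, it uses Python's built-in `round`, which rounds halves to the nearest even integer. The expressions `round(max(-x, 0))` are equivalent to the original `round(abs(min(x, 0)))` forms.

```python
from typing import List, Sequence, Tuple


def mask_center_crop_or_pad(
    volume_shape: Sequence[int],
    bb_min: Sequence[int],
    bb_max: Sequence[int],
    target_shape: Sequence[int],
) -> Tuple[List[int], List[int]]:
    """Return (padding, cropping) as (ini, fin) pairs per axis."""
    padding, cropping = [], []
    for size, lo, hi, target in zip(volume_shape, bb_min, bb_max, target_shape):
        center = (hi - lo) / 2 + lo
        # Window of the target size around the mask center, in voxel units
        begin = center - target / 2
        end = center + target / 2
        # Parts of the window outside the volume are padded
        padding += [round(max(-begin, 0)), round(max(end - size, 0))]
        # Parts of the volume outside the window are cropped
        cropping += [round(max(begin, 0)), round(max(size - end, 0))]
    return padding, cropping


padding, cropping = mask_center_crop_or_pad(
    (100, 100, 50), (40, 10, 20), (60, 30, 30), (40, 40, 80))
print(padding)   # [0, 0, 0, 0, 15, 15]
print(cropping)  # [30, 30, 0, 60, 0, 0]
```

Each end is rounded independently, so the resulting size can miss the target by one voxel when the window starts at a half-voxel position. For example, an axis of size 11 with a bounding box spanning indices 2 to 3 and a target of 4 gives a cropping of (0, 6), which leaves 5 voxels. The original formulas behave the same way.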
    def apply_transform(self, sample: dict) -> dict:
        padding_params, cropping_params = self.compute_crop_or_pad(sample)
        padding_kwargs = dict(
            padding_mode=self.padding_mode, fill=self.padding_fill)
        if padding_params is not None:
            sample = Pad(padding_params, **padding_kwargs)(sample)
        if cropping_params is not None:
            sample = Crop(cropping_params)(sample)
        return sample


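`apply_transform` pads first and crops second. Each parameter tuple is in `(d_ini, d_fin, h_ini, h_fin, w_ini, w_fin)` order. The NumPy sketch below applies the same two steps to a plain array, using `np.pad` and slicing in place of torchio's `Pad` and `Crop`. Those classes operate on whole samples and also update the affine. This is an illustration only. The function name is hypothetical, and constant (zero) padding is an assumed default.

```python
from typing import Optional, Sequence

import numpy as np


def pad_then_crop(
    volume: np.ndarray,
    padding: Optional[Sequence[int]],
    cropping: Optional[Sequence[int]],
    mode: str = 'constant',
) -> np.ndarray:
    """Pad, then crop, a 3D array using six (ini, fin) bounds per step."""
    if padding is not None:
        # np.pad expects ((d_ini, d_fin), (h_ini, h_fin), (w_ini, w_fin))
        pad_width = [(padding[2 * i], padding[2 * i + 1]) for i in range(3)]
        volume = np.pad(volume, pad_width, mode=mode)
    if cropping is not None:
        # Remove `ini` voxels from the start and `fin` from the end of each axis
        slices = tuple(
            slice(cropping[2 * i], volume.shape[i] - cropping[2 * i + 1])
            for i in range(3)
        )
        volume = volume[slices]
    return volume


volume = np.arange(10 * 10 * 5).reshape(10, 10, 5)
result = pad_then_crop(volume, (0, 0, 0, 0, 2, 1), (3, 3, 2, 4, 0, 0))
print(result.shape)  # (4, 4, 8)
```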
@deprecated('CenterCropOrPad is deprecated. Use CropOrPad instead.')
Bug introduced by
The method bounds_function, which was declared abstract in the superclass BoundsTransform,
was not overridden.

Methods which raise NotImplementedError should be overridden in concrete child classes.
class CenterCropOrPad(CropOrPad):
    """Crop and/or pad an image to a target shape using the image center.

    This class is deprecated; use :py:class:`CropOrPad` instead.

    Args:
        target_shape: Tuple :math:`(D, H, W)`. If a single value :math:`N` is
            provided, then :math:`D = H = W = N`.
        padding_mode: See :py:class:`~torchio.transforms.Pad`.
        padding_fill: Same as :attr:`fill` in
            :py:class:`~torchio.transforms.Pad`.
    """

    def __init__(
            self,
            target_shape: Union[int, Tuple[int, int, int]],
            padding_mode: str = 'constant',
            padding_fill: Optional[float] = None,
            ):
Coding Style introduced by
Wrong hanging indentation before block.
        super().__init__(
            target_shape=target_shape,
            padding_mode=padding_mode,
            padding_fill=padding_fill,
            mode='center',
        )
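The "bounds_function ... was not overridden" report refers to a common pattern. A base class declares a method whose body only raises `NotImplementedError`, and the analyzer expects every concrete subclass to override it. The following is a generic, self-contained illustration of that pattern and of an override that satisfies it. All class names are hypothetical, and the example says nothing about torchio's actual `BoundsTransform` interface.

```python
class BaseTransform:
    """Hypothetical base class with an 'abstract' method."""

    def bounds_function(self, data):
        # Subclasses are expected to replace this
        raise NotImplementedError


class Incomplete(BaseTransform):
    """Would be flagged: bounds_function is inherited, not overridden."""


class Complete(BaseTransform):
    """Overrides bounds_function, so the pattern is satisfied."""

    def bounds_function(self, data):
        return data


try:
    Incomplete().bounds_function([1, 2])
except NotImplementedError:
    print('Incomplete raises NotImplementedError')
print(Complete().bounds_function([1, 2]))  # [1, 2]
```

With this pattern the error only surfaces when the method is called. Declaring the base class with `abc.ABC` and the method with `@abc.abstractmethod` moves the failure earlier: Python refuses to instantiate any subclass that has not overridden the method.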