Completed: push to master (68ed91...56386c) by Fernando, created 01:25

CropOrPad._compute_mask_center_crop_or_pad()   C

Complexity

Conditions 9

Size

Total Lines 70
Code Lines 47

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric   Value
eloc     47
dl       0
loc      70
rs       6.4012
c        0
b        0
f        0
cc       9
nop      2

How to fix: Long Method

Small methods make your code easier to understand, especially when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.

Commonly applied refactorings include Extract Method, sketched below.
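A minimal Extract Method sketch in Python follows; the function and variable names are purely illustrative and are not taken from the code reviewed below.

# Before: a comment marks a block that wants to be its own method
def preprocess(image):
    # normalize intensities to the [0, 1] range
    minimum = image.min()
    maximum = image.max()
    return (image - minimum) / (maximum - minimum)

# After: the commented block becomes a method and the comment becomes its name
def preprocess(image):
    return normalize_intensities(image)

def normalize_intensities(image):
    minimum = image.min()
    maximum = image.max()
    return (image - minimum) / (maximum - minimum)

The original method shrinks to a single, readable call, and the name of the new method documents what the block does. The full source of the file containing the flagged method follows.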

import warnings
from typing import Union, Tuple, Optional

import numpy as np

from .pad import Pad
from .crop import Crop
from .bounds_transform import BoundsTransform
from ...transform import TypeTripletInt, TypeSixBounds
from ....data.subject import Subject


class CropOrPad(BoundsTransform):
    """Crop and/or pad an image to a target shape.

    This transform modifies the affine matrix associated to the volume so that
    physical positions of the voxels are maintained.

    Args:
        target_shape: Tuple :math:`(W, H, D)`. If a single value :math:`N` is
            provided, then :math:`W = H = D = N`.
        padding_mode: Same as :attr:`padding_mode` in
            :class:`~torchio.transforms.Pad`.
        mask_name: If ``None``, the centers of the input and output volumes
            will be the same.
            If a string is given, the output volume center will be the center
            of the bounding box of non-zero values in the image named
            :attr:`mask_name`.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Example:
        >>> import torchio as tio
        >>> subject = tio.Subject(
        ...     chest_ct=tio.ScalarImage('subject_a_ct.nii.gz'),
        ...     heart_mask=tio.LabelMap('subject_a_heart_seg.nii.gz'),
        ... )
        >>> subject.chest_ct.shape
        torch.Size([1, 512, 512, 289])
        >>> transform = tio.CropOrPad(
        ...     (120, 80, 180),
        ...     mask_name='heart_mask',
        ... )
        >>> transformed = transform(subject)
        >>> transformed.chest_ct.shape
        torch.Size([1, 120, 80, 180])
    """
    def __init__(
            self,
            target_shape: Union[int, TypeTripletInt],
            padding_mode: Union[str, float] = 0,
            mask_name: Optional[str] = None,
            **kwargs
            ):
        super().__init__(target_shape, **kwargs)
        self.padding_mode = padding_mode
        if mask_name is not None and not isinstance(mask_name, str):
            message = (
                'If mask_name is not None, it must be a string,'
                f' not {type(mask_name)}'
            )
            raise ValueError(message)
        self.mask_name = mask_name
        if self.mask_name is None:
            self.compute_crop_or_pad = self._compute_center_crop_or_pad
        else:
            if not isinstance(mask_name, str):
                message = (
                    'If mask_name is not None, it must be a string,'
                    f' not {type(mask_name)}'
                )
                raise ValueError(message)
            self.compute_crop_or_pad = self._compute_mask_center_crop_or_pad

    @staticmethod
    def _bbox_mask(mask_volume: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Return 6 coordinates of a 3D bounding box from a given mask.

        Taken from `this SO question <https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array>`_.

        Args:
            mask_volume: 3D NumPy array.
        """  # noqa: E501
        i_any = np.any(mask_volume, axis=(1, 2))
        j_any = np.any(mask_volume, axis=(0, 2))
        k_any = np.any(mask_volume, axis=(0, 1))
        i_min, i_max = np.where(i_any)[0][[0, -1]]
        j_min, j_max = np.where(j_any)[0][[0, -1]]
        k_min, k_max = np.where(k_any)[0][[0, -1]]
        bb_min = np.array([i_min, j_min, k_min])
        bb_max = np.array([i_max, j_max, k_max]) + 1
        return bb_min, bb_max

    @staticmethod
    def _get_six_bounds_parameters(
            parameters: np.ndarray,
            ) -> TypeSixBounds:
        r"""Compute bounds parameters for ITK filters.

        Args:
            parameters: Tuple :math:`(w, h, d)` with the number of voxels to be
                cropped or padded.

        Returns:
            Tuple :math:`(w_{ini}, w_{fin}, h_{ini}, h_{fin}, d_{ini}, d_{fin})`,
            where :math:`n_{ini} = \left \lceil \frac{n}{2} \right \rceil` and
            :math:`n_{fin} = \left \lfloor \frac{n}{2} \right \rfloor`.

        Example:
            >>> p = np.array((4, 0, 7))
            >>> CropOrPad._get_six_bounds_parameters(p)
            (2, 2, 0, 0, 4, 3)
        """  # noqa: E501
        parameters = parameters / 2
        result = []
        for number in parameters:
            ini, fin = int(np.ceil(number)), int(np.floor(number))
            result.extend([ini, fin])
        return tuple(result)

    @property
    def target_shape(self):
        return self.bounds_parameters[::2]

    def _compute_cropping_padding_from_shapes(
            self,
            source_shape: TypeTripletInt,
            target_shape: TypeTripletInt,
            ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        diff_shape = target_shape - source_shape

        cropping = -np.minimum(diff_shape, 0)
        if cropping.any():
            cropping_params = self._get_six_bounds_parameters(cropping)
        else:
            cropping_params = None

        padding = np.maximum(diff_shape, 0)
        if padding.any():
            padding_params = self._get_six_bounds_parameters(padding)
        else:
            padding_params = None

        return padding_params, cropping_params

    def _compute_center_crop_or_pad(
            self,
            subject: Subject,
            ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        source_shape = subject.spatial_shape
        # The parent class turns the 3-element shape tuple (w, h, d)
        # into a 6-element bounds tuple (w, w, h, h, d, d)
        target_shape = np.array(self.bounds_parameters[::2])
        parameters = self._compute_cropping_padding_from_shapes(
            source_shape, target_shape)
        padding_params, cropping_params = parameters
        return padding_params, cropping_params

    def _compute_mask_center_crop_or_pad(
            self,
            subject: Subject,
            ) -> Tuple[Optional[TypeSixBounds], Optional[TypeSixBounds]]:
        if self.mask_name not in subject:
            message = (
                f'Mask name "{self.mask_name}"'
                f' not found in subject keys "{tuple(subject.keys())}".'
                ' Using volume center instead'
            )
            warnings.warn(message, RuntimeWarning)
            return self._compute_center_crop_or_pad(subject=subject)

        mask = subject[self.mask_name].numpy()

        if not np.any(mask):
            message = (
                f'All values found in the mask "{self.mask_name}"'
                ' are zero. Using volume center instead'
            )
            warnings.warn(message, RuntimeWarning)
            return self._compute_center_crop_or_pad(subject=subject)

        # Let's assume that the center of first voxel is at coordinate 0.5
        # (which is typically not the case)
        subject_shape = subject.spatial_shape
        bb_min, bb_max = self._bbox_mask(mask[0])
        center_mask = np.mean((bb_min, bb_max), axis=0)
        padding = []
        cropping = []
        target_shape = np.array(self.target_shape)

        for dim in range(3):
            target_dim = target_shape[dim]
            center_dim = center_mask[dim]
            subject_dim = subject_shape[dim]

            center_on_index = not (center_dim % 1)
            target_even = not (target_dim % 2)

            # Approximation when the center cannot be computed exactly
            # The output will be off by half a voxel, but this is just an
            # implementation detail
            if target_even ^ center_on_index:
                center_dim -= 0.5

            begin = center_dim - target_dim / 2
            if begin >= 0:
                crop_ini = begin
                pad_ini = 0
            else:
                crop_ini = 0
                pad_ini = -begin

            end = center_dim + target_dim / 2
            if end <= subject_dim:
                crop_fin = subject_dim - end
                pad_fin = 0
            else:
                crop_fin = 0
                pad_fin = end - subject_dim

            padding.extend([pad_ini, pad_fin])
            cropping.extend([crop_ini, crop_fin])
        # Conversion for SimpleITK compatibility
        padding = np.asarray(padding, dtype=int)
        cropping = np.asarray(cropping, dtype=int)
        padding_params = tuple(padding.tolist()) if padding.any() else None
        cropping_params = tuple(cropping.tolist()) if cropping.any() else None
        return padding_params, cropping_params

    def apply_transform(self, subject: Subject) -> Subject:
        padding_params, cropping_params = self.compute_crop_or_pad(subject)
        padding_kwargs = {'padding_mode': self.padding_mode}
        if padding_params is not None:
            subject = Pad(padding_params, **padding_kwargs)(subject)
        if cropping_params is not None:
            subject = Crop(cropping_params)(subject)
        actual, target = subject.spatial_shape, self.target_shape
        assert actual == target, (actual, target)
        return subject
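As an illustration of the recommended fix, the per-dimension arithmetic inside the loop of _compute_mask_center_crop_or_pad could be extracted into a static helper. This is only a sketch; the helper name _crop_pad_for_dimension is hypothetical and not part of torchio.

    @staticmethod
    def _crop_pad_for_dimension(center_dim, target_dim, subject_dim):
        # Hypothetical helper: padding and cropping amounts for one dimension
        center_on_index = not (center_dim % 1)
        target_even = not (target_dim % 2)
        # Shift by half a voxel when the center cannot be matched exactly
        if target_even ^ center_on_index:
            center_dim -= 0.5
        begin = center_dim - target_dim / 2
        end = center_dim + target_dim / 2
        if begin >= 0:
            crop_ini, pad_ini = begin, 0
        else:
            crop_ini, pad_ini = 0, -begin
        if end <= subject_dim:
            crop_fin, pad_fin = subject_dim - end, 0
        else:
            crop_fin, pad_fin = 0, end - subject_dim
        return (pad_ini, pad_fin), (crop_ini, crop_fin)

The loop body in _compute_mask_center_crop_or_pad would then reduce to:

    for dim in range(3):
        pads, crops = self._crop_pad_for_dimension(
            center_mask[dim], target_shape[dim], subject_shape[dim])
        padding.extend(pads)
        cropping.extend(crops)

Moving these three conditionals into the helper should lower the flagged method's condition count (Conditions 9 above) and its line count without changing its behavior.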