Completed
Pull Request — master (#401)
by Fernando
20:26 queued 20:26
created

torchio.transforms.preprocessing.spatial.ensure_shape_multiple   A

Complexity

Total Complexity 4

Size/Duplication

Total Lines 63
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 27
dl 0
loc 63
rs 10
c 0
b 0
f 0
wmc 4

2 Methods

Rating   Name   Duplication   Size   Complexity  
A EnsureShapeMultiple.__init__() 0 12 2
A EnsureShapeMultiple.apply_transform() 0 7 2
1
from typing import Union, Optional
2
3
import numpy as np
4
5
from ... import SpatialTransform
6
from ....utils import to_tuple
7
from ....data.subject import Subject
8
from ....typing import TypeTripletInt
9
from .crop_or_pad import CropOrPad
10
11
12
class EnsureShapeMultiple(SpatialTransform):
    """Crop or pad an image to a shape that is a multiple of :math:`N`.

    For each spatial dimension of size :math:`s` and target multiple
    :math:`n`, the output size is the largest multiple of :math:`n` that is
    not larger than :math:`s` when cropping, and the smallest multiple of
    :math:`n` that is not smaller than :math:`s` when padding. Sizes that
    would become zero (cropping a dimension smaller than its multiple) are
    kept as 1, so the singleton axis of a 2D image survives cropping. Note
    that padding a singleton axis with :math:`n > 1` enlarges it to
    :math:`n`; pass e.g. ``(8, 8, 1)`` to keep 2D images 2D when padding.

    Args:
        target_multiple: Tuple :math:`(w, h, d)`. If a single value :math:`n` is
            provided, then :math:`w = h = d = n`. All values must be at
            least 1.
        method: Either ``'crop'`` or ``'pad'``. Default: ``'pad'``.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        ValueError: If ``target_multiple`` does not describe three values
            greater than or equal to 1, or if ``method`` is not ``'crop'``
            or ``'pad'``.

    Example:
        >>> import torchio as tio
        >>> image = tio.datasets.Colin27().t1
        >>> image.shape
        (1, 181, 217, 181)
        >>> transform = tio.EnsureShapeMultiple(8, method='pad')
        >>> transformed = transform(image)
        >>> transformed.shape
        (1, 184, 224, 184)
        >>> transform = tio.EnsureShapeMultiple(8, method='crop')
        >>> transformed = transform(image)
        >>> transformed.shape
        (1, 176, 216, 176)
        >>> image_2d = image.data[..., :1]
        >>> image_2d.shape
        torch.Size([1, 181, 217, 1])
        >>> transformed = transform(image_2d)
        >>> transformed.shape
        torch.Size([1, 176, 216, 1])

    """
    def __init__(
            self,
            target_multiple: Union[int, TypeTripletInt],
            *,
            method: Optional[str] = 'pad',
            **kwargs
            ):
        super().__init__(**kwargs)
        multiple = np.array(to_tuple(target_multiple, 3))
        # Validate eagerly: a wrong length would otherwise only fail later
        # with a broadcasting error in apply_transform, and values below 1
        # lead to division by zero or meaningless target shapes
        if multiple.shape != (3,):
            message = (
                'Target multiple must be a single value or a sequence of'
                f' 3 values, but "{target_multiple}" was passed'
            )
            raise ValueError(message)
        if np.any(multiple < 1):
            message = (
                'All target multiple values must be greater than or equal'
                f' to 1, but "{target_multiple}" was passed'
            )
            raise ValueError(message)
        self.target_multiple = multiple
        if method not in ('crop', 'pad'):
            raise ValueError('Method must be "crop" or "pad"')
        self.method = method

    def apply_transform(self, subject: Subject) -> Subject:
        # Use a plain signed integer dtype: np.uint16 would wrap around (or
        # raise, depending on the NumPy version) for sizes above 65535
        source_shape = np.array(subject.spatial_shape, dtype=int)
        # Floor gives the largest multiple that fits inside the image (crop);
        # ceil gives the smallest multiple that contains it (pad)
        function = np.floor if self.method == 'crop' else np.ceil
        integer_ratio = function(source_shape / self.target_multiple)
        target_shape = integer_ratio * self.target_multiple
        # Cropping a dimension smaller than its multiple would yield size 0;
        # keep at least one voxel (e.g. the singleton axis of a 2D image)
        target_shape = np.maximum(target_shape, 1)
        return CropOrPad(target_shape.astype(int))(subject)
63