Completed
Pull Request — master (#401)
by Fernando
01:02
created

torchio.transforms.preprocessing.spatial.ensure_shape_multiple   A

Complexity

Total Complexity 4

Size/Duplication

Total Lines 54
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 25
dl 0
loc 54
rs 10
c 0
b 0
f 0
wmc 4

2 Methods

Rating   Name   Duplication   Size   Complexity  
A EnsureShapeMultiple.__init__() 0 11 2
A EnsureShapeMultiple.apply_transform() 0 6 2
1
from typing import Union, Optional
2
3
import numpy as np
4
5
from ... import SpatialTransform
6
from ....utils import to_tuple
7
from ....data.subject import Subject
8
from ....typing import TypeTripletInt
9
from .crop_or_pad import CropOrPad
10
11
12
class EnsureShapeMultiple(SpatialTransform):
    """Crop and/or pad an image to a shape that is a multiple of :math:`N`.

    Args:
        target_multiple: Tuple :math:`(w, h, d)`. If a single value :math:`n` is
            provided, then :math:`w = h = d = n`. All values must be positive.
        method: Either ``'crop'`` or ``'pad'``. With ``'crop'``, each spatial
            size is rounded down to the nearest multiple; with ``'pad'``, it
            is rounded up. A size that would be rounded down to zero is kept
            at 1 instead, so the output is never empty (this preserves
            singleton dimensions, e.g. in 2D images).
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        ValueError: If any value of ``target_multiple`` is not positive, or
            if ``method`` is neither ``'crop'`` nor ``'pad'``.

    Example:
        >>> import torchio as tio
        >>> image = tio.datasets.FPG().t1
        >>> image.shape
        (1, 181, 217, 181)
        >>> transform = tio.EnsureShapeMultiple(8, method='pad')
        >>> transformed = transform(image)
        >>> transformed.shape
        (1, 184, 224, 184)
        >>> transform = tio.EnsureShapeMultiple(8, method='crop')
        >>> transformed = transform(image)
        >>> transformed.shape
        (1, 176, 216, 176)
    """
    def __init__(
            self,
            target_multiple: Union[int, TypeTripletInt],
            method: Optional[str] = 'pad',
            **kwargs
            ):
        super().__init__(**kwargs)
        self.target_multiple = np.array(to_tuple(target_multiple, 3))
        # A zero or negative multiple would divide by zero (inf/nan) or
        # produce a meaningless target shape in apply_transform.
        if np.any(self.target_multiple <= 0):
            message = (
                'All values of target_multiple must be positive, got'
                f' {tuple(self.target_multiple.tolist())}'
            )
            raise ValueError(message)
        if method not in ('crop', 'pad'):
            raise ValueError('Method must be "crop" or "pad"')
        self.method = method

    def apply_transform(self, subject: Subject) -> Subject:
        """Crop or pad ``subject`` to the nearest multiple-of-N shape.

        The new spatial shape is computed per axis as
        ``round(size / multiple) * multiple``. ``round`` is floor for
        ``'crop'`` and ceil for ``'pad'``. Each axis is clamped to at least 1.
        The actual resizing is delegated to :class:`CropOrPad`.
        """
        # Use a wide signed integer type: uint16 would silently wrap around
        # for spatial sizes larger than 65535.
        source_shape = np.array(subject.spatial_shape, dtype=np.int64)
        round_function = np.floor if self.method == 'crop' else np.ceil
        integer_ratio = round_function(source_shape / self.target_multiple)
        target_shape = integer_ratio * self.target_multiple
        # When cropping, a size smaller than the multiple floors to 0, which
        # would request an empty image; keep such axes at size 1 instead.
        target_shape = np.maximum(target_shape, 1)
        return CropOrPad(target_shape.astype(int))(subject)
54