|
1
|
|
|
from typing import Union, Optional |
|
2
|
|
|
|
|
3
|
|
|
import numpy as np |
|
4
|
|
|
|
|
5
|
|
|
from ... import SpatialTransform |
|
6
|
|
|
from ....utils import to_tuple |
|
7
|
|
|
from ....data.subject import Subject |
|
8
|
|
|
from ....typing import TypeTripletInt |
|
9
|
|
|
from .crop_or_pad import CropOrPad |
|
10
|
|
|
|
|
11
|
|
|
|
|
12
|
|
|
class EnsureShapeMultiple(SpatialTransform):
    """Crop or pad an image to a shape that is a multiple of :math:`N`.

    Each spatial dimension is rounded to a multiple of the corresponding
    target value: down when ``method='crop'`` and up when ``method='pad'``.
    The resulting shape is then enforced with
    :class:`~torchio.transforms.CropOrPad`. A dimension is never reduced
    below 1, so a singleton dimension (e.g. the last one of a 2D image stored
    as a volume) is preserved when cropping.

    Args:
        target_multiple: Tuple :math:`(w, h, d)`. If a single value :math:`n` is
            provided, then :math:`w = h = d = n`. All values must be at
            least 1.
        method: Either ``'crop'`` or ``'pad'``.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        ValueError: If ``method`` is not ``'crop'`` or ``'pad'``, or if any
            value in ``target_multiple`` is smaller than 1.

    Example:
        >>> import torchio as tio
        >>> image = tio.datasets.Colin27().t1
        >>> image.shape
        (1, 181, 217, 181)
        >>> transform = tio.EnsureShapeMultiple(8, method='pad')
        >>> transformed = transform(image)
        >>> transformed.shape
        (1, 184, 224, 184)
        >>> transform = tio.EnsureShapeMultiple(8, method='crop')
        >>> transformed = transform(image)
        >>> transformed.shape
        (1, 176, 216, 176)
        >>> image_2d = image.data[..., :1]
        >>> image_2d.shape
        torch.Size([1, 181, 217, 1])
        >>> transformed = transform(image_2d)
        >>> transformed.shape
        torch.Size([1, 176, 216, 1])

    """

    def __init__(
        self,
        target_multiple: Union[int, TypeTripletInt],
        *,
        method: str = 'pad',
        **kwargs,
    ):
        super().__init__(**kwargs)
        # One multiple per spatial dimension
        self.target_multiple = np.array(to_tuple(target_multiple, 3))
        if method not in ('crop', 'pad'):
            raise ValueError('Method must be "crop" or "pad"')
        # A multiple below 1 would divide by zero (or flip signs) when the
        # target shape is computed, yielding a meaningless shape downstream.
        if np.any(self.target_multiple < 1):
            raise ValueError(
                'All values of target_multiple must be at least 1,'
                f' but got {target_multiple}',
            )
        self.method = method

    def apply_transform(self, subject: Subject) -> Subject:
        # Use a wide signed integer type: a narrow unsigned type such as
        # uint16 would silently wrap around for dimensions above 65535.
        source_shape = np.array(subject.spatial_shape, dtype=np.int64)
        multiple = self.target_multiple
        if self.method == 'crop':
            # Round each dimension down to the nearest multiple
            target_shape = (source_shape // multiple) * multiple
        else:
            # Round up; for positive b, ceil(a / b) == -(-a // b), which is
            # exact integer arithmetic with no float division
            target_shape = -(-source_shape // multiple) * multiple
        # Rounding down yields 0 when a dimension is smaller than its
        # multiple (e.g. the singleton axis of a 2D image); keep at least 1
        target_shape = np.maximum(target_shape, 1)
        return CropOrPad(target_shape.astype(int))(subject)
|
63
|
|
|
|