1
|
|
|
from typing import Union, Optional |
2
|
|
|
|
3
|
|
|
import numpy as np |
4
|
|
|
|
5
|
|
|
from ... import SpatialTransform |
6
|
|
|
from ....utils import to_tuple |
7
|
|
|
from ....data.subject import Subject |
8
|
|
|
from ....typing import TypeTripletInt |
9
|
|
|
from .crop_or_pad import CropOrPad |
10
|
|
|
|
11
|
|
|
|
12
|
|
|
class EnsureShapeMultiple(SpatialTransform):
    """Crop or pad an image to a shape that is a multiple of :math:`N`.

    Args:
        target_multiple: Tuple :math:`(w, h, d)`. If a single value :math:`n` is
            provided, then :math:`w = h = d = n`. All values must be positive.
        method: Either ``'crop'`` or ``'pad'``. With ``'crop'``, each spatial
            size is rounded down to the nearest multiple; with ``'pad'``, it
            is rounded up. A size that would be rounded down to zero (e.g. a
            singleton dimension of a 2D image when cropping) is kept as 1.
        **kwargs: See :class:`~torchio.transforms.Transform` for additional
            keyword arguments.

    Raises:
        ValueError: If any value of ``target_multiple`` is not positive, or
            if ``method`` is not ``'crop'`` or ``'pad'``.

    Example:
        >>> import torchio as tio
        >>> image = tio.datasets.Colin27().t1
        >>> image.shape
        (1, 181, 217, 181)
        >>> transform = tio.EnsureShapeMultiple(8, method='pad')
        >>> transformed = transform(image)
        >>> transformed.shape
        (1, 184, 224, 184)
        >>> transform = tio.EnsureShapeMultiple(8, method='crop')
        >>> transformed = transform(image)
        >>> transformed.shape
        (1, 176, 216, 176)
        >>> image_2d = image.data[..., :1]
        >>> image_2d.shape
        torch.Size([1, 181, 217, 1])
        >>> transformed = transform(image_2d)
        >>> transformed.shape
        torch.Size([1, 176, 216, 1])
    """

    def __init__(
        self,
        target_multiple: Union[int, TypeTripletInt],
        *,
        method: str = 'pad',
        **kwargs
    ):
        super().__init__(**kwargs)
        self.target_multiple = np.array(to_tuple(target_multiple, 3))
        # A zero or negative multiple would divide by zero (or flip signs) in
        # apply_transform and yield a meaningless target shape, so fail early.
        if np.any(self.target_multiple <= 0):
            message = (
                'Target multiple values must be positive,'
                f' but got {target_multiple}'
            )
            raise ValueError(message)
        if method not in ('crop', 'pad'):
            raise ValueError('Method must be "crop" or "pad"')
        self.method = method

    def apply_transform(self, subject: Subject) -> Subject:
        """Crop or pad ``subject`` so each spatial size is a multiple.

        Args:
            subject: Subject whose ``spatial_shape`` is rounded (down for
                ``'crop'``, up for ``'pad'``) to a multiple of
                ``self.target_multiple``.

        Returns:
            The result of applying :class:`CropOrPad` with the computed
            target shape to ``subject``.
        """
        # Default integer dtype: the previous uint16 silently wrapped spatial
        # sizes larger than 65535.
        source_shape = np.array(subject.spatial_shape, dtype=int)
        round_fn = np.floor if self.method == 'crop' else np.ceil
        integer_ratio = round_fn(source_shape / self.target_multiple)
        target_shape = integer_ratio * self.target_multiple
        # Cropping a dimension smaller than the multiple rounds it to 0;
        # keep it at 1 so e.g. the singleton axis of a 2D image survives.
        target_shape = np.maximum(target_shape, 1)
        # NOTE(review): the base-transform kwargs given to __init__ (e.g.
        # include/exclude) are not forwarded to this inner CropOrPad —
        # confirm whether that is intended.
        return CropOrPad(target_shape.astype(int))(subject)
63
|
|
|
|