Passed
Pull Request — master (#182)
by Fernando
01:18
created

torchio.data.inference.grid_sampler.crop()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 7
nop 3
dl 0
loc 8
rs 10
c 0
b 0
f 0
1
import numpy as np
2
from torch.utils.data import Dataset
3
4
from ..sampler.sampler import PatchSampler
5
from ...utils import to_tuple
6
from ...torchio import LOCATION, TypeTuple, TypeTripletInt
7
from ..subject import Subject
8
9
10
class GridSampler(PatchSampler, Dataset):
    r"""Extract patches across a whole volume.

    Grid samplers are useful to perform inference using all patches from a
    volume. It is often used with a :py:class:`~torchio.data.GridAggregator`.

    Args:
        sample: Instance of :py:class:`~torchio.data.subject.Subject`
            from which patches will be extracted.
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided,
            :math:`d = h = w = n`.
        patch_overlap: Tuple of even integers :math:`(d_o, h_o, w_o)` specifying
            the overlap between patches for dense inference. If a single number
            :math:`n` is provided, :math:`d_o = h_o = w_o = n`.

    Raises:
        ValueError: If the patch size is larger than the image size along any
            axis, if the patch overlap is not strictly smaller than the patch
            size, or if any patch overlap value is odd.

    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
        information about patch based sampling. Note that
        :py:attr:`patch_overlap` is twice :py:attr:`border` in NiftyNet
        tutorial.
    """
    def __init__(
            self,
            sample: Subject,
            patch_size: TypeTuple,
            patch_overlap: TypeTuple = (0, 0, 0),
            ):
        self.sample = sample
        PatchSampler.__init__(self, patch_size)
        self.patch_overlap = to_tuple(patch_overlap, length=3)
        sizes = self.sample.spatial_shape, self.patch_size, self.patch_overlap
        # Validate first so that get_patches_locations only sees sane sizes
        self.parse_sizes(*sizes)
        self.locations = self.get_patches_locations(*sizes)

    def __len__(self):
        """Return the number of patches on the grid."""
        return len(self.locations)

    def __getitem__(self, index):
        """Return the patch at ``index``, annotated with its grid location.

        The returned sample stores its location row (start indices followed
        by end indices) under the ``LOCATION`` key.
        """
        # Assume 3D: the first three values are the start indices, the last
        # three the (exclusive) end indices
        location = self.locations[index]
        index_ini = location[:3]
        index_fin = location[3:]
        cropped_sample = self.extract_patch(self.sample, index_ini, index_fin)
        cropped_sample[LOCATION] = location
        return cropped_sample

    @staticmethod
    def parse_sizes(
            image_size: TypeTripletInt,
            patch_size: TypeTripletInt,
            patch_overlap: TypeTripletInt,
            ) -> None:
        """Check that image, patch and overlap sizes are compatible.

        Raises:
            ValueError: If ``patch_size`` exceeds ``image_size`` along any
                axis, if ``patch_overlap`` is not strictly smaller than
                ``patch_size`` along every axis, or if any overlap is odd.
        """
        image_size = np.array(image_size)
        patch_size = np.array(patch_size)
        patch_overlap = np.array(patch_overlap)
        if np.any(patch_size > image_size):
            message = (
                f'Patch size {tuple(patch_size)} cannot be'
                f' larger than image size {tuple(image_size)}'
            )
            raise ValueError(message)
        if np.any(patch_overlap >= patch_size):
            # Report the patch size (not the image size) this check refers to
            message = (
                f'Patch overlap {tuple(patch_overlap)} must be smaller'
                f' than patch size {tuple(patch_size)}'
            )
            raise ValueError(message)
        if np.any(patch_overlap % 2):
            message = (
                'Patch overlap must be a tuple of even integers,'
                f' not {tuple(patch_overlap)}'
            )
            raise ValueError(message)

    def extract_patch(
            self,
            sample: Subject,
            index_ini: TypeTripletInt,
            index_fin: TypeTripletInt,
            ) -> Subject:
        """Crop ``sample`` to the box ``[index_ini, index_fin)``.

        ``index_fin - index_ini`` is computed element-wise, so the indices are
        expected to be NumPy arrays (as produced by ``get_patches_locations``),
        not plain tuples.
        """
        crop = self.get_crop_transform(
            sample.spatial_shape,
            index_ini,
            index_fin - index_ini,
        )
        cropped_sample = crop(sample)
        return cropped_sample

    @staticmethod
    def get_patches_locations(
            image_size: TypeTripletInt,
            patch_size: TypeTripletInt,
            patch_overlap: TypeTripletInt,
            ) -> np.ndarray:
        """Compute the locations of all patches on a regular 3D grid.

        Along each axis, patches start every ``patch_size - patch_overlap``
        voxels. If the last regular start does not place a patch flush with
        the image border, one extra patch starting at
        ``image_size - patch_size`` is added so the whole image is covered.

        Returns:
            Integer array with one row per patch, laid out as
            ``[i_ini, j_ini, k_ini, i_fin, j_fin, k_fin]`` (end indices are
            exclusive), with unique rows sorted lexicographically.
        """
        indices = []
        zipped = zip(image_size, patch_size, patch_overlap)
        for im_size_dim, patch_size_dim, patch_overlap_dim in zipped:
            end = im_size_dim + 1 - patch_size_dim
            step = patch_size_dim - patch_overlap_dim
            indices_dim = list(range(0, end, step))
            last_start = im_size_dim - patch_size_dim
            # The decisive question is whether the final patch reaches the
            # border, not whether the image size is a multiple of the step
            # (e.g. size 8, patch 6, overlap 2: 8 % 4 == 0 yet voxels 6-7
            # would otherwise be skipped)
            if not indices_dim or indices_dim[-1] != last_start:
                indices_dim.append(last_start)
            indices.append(indices_dim)
        indices_ini = np.array(np.meshgrid(*indices)).reshape(3, -1).T
        indices_ini = np.unique(indices_ini, axis=0)
        indices_fin = indices_ini + np.array(patch_size)
        locations = np.hstack((indices_ini, indices_fin))
        return np.array(sorted(locations.tolist()))
120