Passed
Pull Request — master (#182)
by Fernando
55s
created

GridSampler.get_patches_locations()   A

Complexity

Conditions 3

Size

Total Lines 29
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 19
nop 3
dl 0
loc 29
rs 9.45
c 0
b 0
f 0
1
from typing import Union
2
3
import numpy as np
4
from torch.utils.data import Dataset
5
6
from ...utils import to_tuple
7
from ...torchio import LOCATION, TypeTuple, TypeTripletInt
8
from ..subject import Subject
9
from ..sampler.sampler import PatchSampler
10
11
12
class GridSampler(PatchSampler, Dataset):
13
    r"""Extract patches across a whole volume.
14
15
    Grid samplers are useful to perform inference using all patches from a
16
    volume. It is often used with a :py:class:`~torchio.data.GridAggregator`.
17
18
    Args:
19
        sample: Instance of :py:class:`~torchio.data.subject.Subject`
20
            from which patches will be extracted.
21
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
22
            of size :math:`d \times h \times w`.
23
            If a single number :math:`n` is provided,
24
            :math:`d = h = w = n`.
25
        patch_overlap: Tuple of even integers :math:`(d_o, h_o, w_o)` specifying
26
            the overlap between patches for dense inference. If a single number
27
            :math:`n` is provided, :math:`d_o = h_o = w_o = n`.
28
        padding_mode: Same as :attr:`padding_mode` in
29
            :py:class:`~torchio.transforms.Pad`. If ``None``, the volume will
30
            not be padded before sampling and patches at the border will not be
31
            cropped by the aggregator. Otherwise, the volume will be padded with
32
            :math:`\left(\frac{d_o}{2}, \frac{h_o}{2}, \frac{w_o}{2}\right)`
33
            on each side before sampling. If the sampler is passed to a
34
            :py:class:`~torchio.data.GridAggregator`, it will crop the output
35
            to its original size.
36
37
    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
38
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
39
        information about patch based sampling. Note that
40
        :py:attr:`patch_overlap` is twice :py:attr:`border` in NiftyNet
41
        tutorial.
42
    """
43
    def __init__(
            self,
            sample: Subject,
            patch_size: TypeTuple,
            patch_overlap: TypeTuple = (0, 0, 0),
            padding_mode: Union[str, float, None] = None,
            ):
        # Store the inputs; the overlap is normalized to a 3-element array
        # so that it can be used in element-wise arithmetic below.
        self.sample = sample
        self.patch_overlap = np.array(to_tuple(patch_overlap, length=3))
        self.padding_mode = padding_mode

        if padding_mode is not None:
            # NOTE(review): imported locally, presumably to avoid a circular
            # import between the data and transforms packages -- confirm.
            from ...transforms import Pad
            half_overlap = self.patch_overlap // 2
            # Each half-overlap value is repeated twice, giving one padding
            # value per side of every spatial dimension.
            pad_transform = Pad(
                half_overlap.repeat(2),
                padding_mode=padding_mode,
            )
            self.sample = pad_transform(self.sample)

        PatchSampler.__init__(self, patch_size)

        # Sizes are validated against the (possibly padded) sample, and the
        # location of every patch is computed once, up front.
        image_shape = self.sample.spatial_shape
        self.parse_sizes(image_shape, self.patch_size, self.patch_overlap)
        self.locations = self.get_patches_locations(
            image_shape,
            self.patch_size,
            self.patch_overlap,
        )
63
64
    def __len__(self):
65
        return len(self.locations)
66
67
    def __getitem__(self, index):
        """Return the patch stored at position ``index`` of the locations.

        The returned patch also carries its location under the ``LOCATION``
        key.
        """
        patch_location = self.locations[index]
        # Assume 3D: the first three values are the start indices and the
        # remaining ones are the end indices
        start, stop = patch_location[:3], patch_location[3:]
        patch = self.extract_patch(self.sample, start, stop)
        patch[LOCATION] = patch_location
        return patch
75
76
    @staticmethod
77
    def parse_sizes(
78
            image_size: TypeTripletInt,
79
            patch_size: TypeTripletInt,
80
            patch_overlap: TypeTripletInt,
81
            ) -> None:
82
        image_size = np.array(image_size)
83
        patch_size = np.array(patch_size)
84
        patch_overlap = np.array(patch_overlap)
85
        if np.any(patch_size > image_size):
86
            message = (
87
                f'Patch size {tuple(patch_size)} cannot be'
88
                f' larger than image size {tuple(image_size)}'
89
            )
90
            raise ValueError(message)
91
        if np.any(patch_overlap >= patch_size):
92
            message = (
93
                f'Patch overlap {tuple(patch_overlap)} must be smaller'
94
                f' than patch size {tuple(image_size)}'
95
            )
96
            raise ValueError(message)
97
        if np.any(patch_overlap % 2):
98
            message = (
99
                'Patch overlap must be a tuple of even integers,'
100
                f' not {tuple(patch_overlap)}'
101
            )
102
            raise ValueError(message)
103
104
    def extract_patch(
            self,
            sample: Subject,
            index_ini: TypeTripletInt,
            index_fin: TypeTripletInt,
            ) -> Subject:
        """Crop ``sample`` to the region between ``index_ini`` and ``index_fin``.

        Args:
            sample: Subject the patch is taken from.
            index_ini: Start indices of the patch.
            index_fin: End indices of the patch. The difference
                ``index_fin - index_ini`` is used as the patch size.

        Returns:
            The result of applying the crop transform to ``sample``.
        """
        image_shape = sample.spatial_shape
        patch_shape = index_fin - index_ini
        crop_transform = self.get_crop_transform(
            image_shape,
            index_ini,
            patch_shape,
        )
        return crop_transform(sample)
117
118
    @staticmethod
119
    def get_patches_locations(
120
            image_size: TypeTripletInt,
121
            patch_size: TypeTripletInt,
122
            patch_overlap: TypeTripletInt,
123
            ) -> np.ndarray:
124
        # Example with image_size 10, patch_size 5, overlap 2:
125
        # [0 1 2 3 4 5 6 7 8 9]
126
        # [0 0 0 0 0]
127
        #       [1 1 1 1 1]
128
        #           [2 2 2 2 2]
129
        # Locations:
130
        # [[0, 5],
131
        #  [3, 8],
132
        #  [5, 10]]
133
        indices = []
134
        zipped = zip(image_size, patch_size, patch_overlap)
135
        for im_size_dim, patch_size_dim, patch_overlap_dim in zipped:
136
            end = im_size_dim + 1 - patch_size_dim
137
            step = patch_size_dim - patch_overlap_dim
138
            indices_dim = list(range(0, end, step))
139
            if indices_dim[-1] != im_size_dim - patch_size_dim:
140
                indices_dim.append(im_size_dim - patch_size_dim)
141
            indices.append(indices_dim)
142
        indices_ini = np.array(np.meshgrid(*indices)).reshape(3, -1).T
143
        indices_ini = np.unique(indices_ini, axis=0)
144
        indices_fin = indices_ini + np.array(patch_size)
145
        locations = np.hstack((indices_ini, indices_fin))
146
        return np.array(sorted(locations.tolist()))
147