Passed
Pull Request — master (#520)
by Fernando
01:15
created

torchio.data.sampler.grid   A

Complexity

Total Complexity 19

Size/Duplication

Total Lines 176
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 19
eloc 97
dl 0
loc 176
rs 10
c 0
b 0
f 0

8 Methods

Rating   Name   Duplication   Size   Complexity  
A GridSampler._get_patches_locations() 0 29 3
A GridSampler.__len__() 0 2 1
A GridSampler._pad() 0 8 2
A GridSampler._compute_locations() 0 6 2
A GridSampler.__getitem__() 0 7 1
A GridSampler._generate_patches() 0 11 2
A GridSampler._parse_sizes() 0 27 4
A GridSampler.__init__() 0 16 4
1
from typing import Union, Generator, Optional
2
3
import numpy as np
4
5
from ...utils import to_tuple
6
from ...constants import LOCATION
7
from ...data.subject import Subject
8
from ...typing import TypePatchSize
9
from ...typing import TypeTripletInt
10
from .sampler import PatchSampler
11
12
13
class GridSampler(PatchSampler):
    r"""Extract patches across a whole volume.

    Grid samplers are useful to perform inference using all patches from a
    volume. It is often used with a :class:`~torchio.data.GridAggregator`.

    Args:
        subject: Instance of :class:`~torchio.data.Subject`
            from which patches will be extracted. This argument should only be
            used before instantiating a :class:`~torchio.data.GridAggregator`,
            or to precompute the number of patches that would be generated from
            a subject.
        patch_size: Tuple of integers :math:`(w, h, d)` to generate patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided,
            :math:`w = h = d = n`.
            This argument is mandatory (it is a keyword argument for backward
            compatibility).
        patch_overlap: Tuple of even integers :math:`(w_o, h_o, d_o)`
            specifying the overlap between patches for dense inference. If a
            single number :math:`n` is provided, :math:`w_o = h_o = d_o = n`.
        padding_mode: Same as :attr:`padding_mode` in
            :class:`~torchio.transforms.Pad`. If ``None``, the volume will not
            be padded before sampling and patches at the border will not be
            cropped by the aggregator.
            Otherwise, the volume will be padded with
            :math:`\left(\frac{w_o}{2}, \frac{h_o}{2}, \frac{d_o}{2} \right)`
            on each side before sampling. If the sampler is passed to a
            :class:`~torchio.data.GridAggregator`, it will crop the output
            to its original size.

    Raises:
        ValueError: If ``patch_size`` is ``None``, if ``subject`` is neither
            ``None`` nor a :class:`~torchio.data.Subject`, or (when a subject
            is given) if the patch size and overlap are not compatible with
            the subject's spatial shape.

    Example::

        >>> import torchio as tio
        >>> sampler = tio.GridSampler(patch_size=88)
        >>> colin = tio.datasets.Colin27()
        >>> for i, patch in enumerate(sampler(colin)):
        ...     patch.t1.save(f'patch_{i}.nii.gz')
        ...
        >>> # To figure out the number of patches beforehand:
        >>> sampler = tio.GridSampler(subject=colin, patch_size=88)
        >>> len(sampler)
        8

    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
        information about patch based sampling. Note that
        :attr:`patch_overlap` is twice :attr:`border` in NiftyNet
        tutorial.
    """
    def __init__(
            self,
            subject: Optional[Subject] = None,
            patch_size: Optional[TypePatchSize] = None,
            patch_overlap: TypePatchSize = (0, 0, 0),
            padding_mode: Union[str, float, None] = None,
            ):
        # patch_size only has a default so that subject can stay the first
        # positional argument; it is effectively mandatory.
        if patch_size is None:
            raise ValueError('A value for patch_size must be given')
        super().__init__(patch_size)
        self.patch_overlap = np.array(to_tuple(patch_overlap, length=3))
        self.padding_mode = padding_mode
        if subject is not None and not isinstance(subject, Subject):
            raise ValueError('The subject argument must be None or Subject')
        # Both attributes stay None when the sampler is built without a
        # subject; in that case patches are produced lazily through
        # _generate_patches().
        self.subject = self._pad(subject)
        self.locations = self._compute_locations(self.subject)

    def __len__(self):
        """Return the number of precomputed patch locations.

        Only meaningful when the sampler was instantiated with a subject;
        otherwise ``self.locations`` is ``None`` and ``len`` raises
        ``TypeError``.
        """
        return len(self.locations)

    def __getitem__(self, index):
        """Return the patch stored at position ``index`` of ``self.locations``.

        The corresponding location row is stored in the returned patch under
        the ``LOCATION`` key. Requires the sampler to have been instantiated
        with a subject.
        """
        # Assume 3D
        location = self.locations[index]
        # The first three values of a location are the initial indices
        index_ini = location[:3]
        cropped_subject = self.crop(self.subject, index_ini, self.patch_size)
        cropped_subject[LOCATION] = location
        return cropped_subject

    def _pad(self, subject: Optional[Subject]) -> Optional[Subject]:
        """Pad ``subject`` by half the patch overlap on each side.

        Args:
            subject: Subject to pad, or ``None`` when the sampler was created
                without a subject.

        Returns:
            The padded subject if :attr:`padding_mode` is not ``None``;
            otherwise the input, unchanged. ``None`` is always returned as-is.
        """
        # Nothing to pad when the sampler was instantiated without a subject
        # (e.g. GridSampler(patch_size=..., padding_mode=...)); applying the
        # transform to None would fail inside __init__.
        if subject is None or self.padding_mode is None:
            return subject
        # NOTE(review): imported locally, presumably to avoid a circular
        # import between the data and transforms packages - confirm.
        from ...transforms import Pad
        border = self.patch_overlap // 2
        # (b_w, b_h, b_d) -> (b_w, b_w, b_h, b_h, b_d, b_d)
        padding = border.repeat(2)
        pad = Pad(padding, padding_mode=self.padding_mode)
        return pad(subject)

    def _compute_locations(
            self,
            subject: Optional[Subject],
            ) -> Optional[np.ndarray]:
        """Validate the sizes and compute all patch locations for ``subject``.

        Returns:
            ``None`` if ``subject`` is ``None``; otherwise the array returned
            by :meth:`_get_patches_locations`.

        Raises:
            ValueError: If :meth:`_parse_sizes` rejects the sizes.
        """
        if subject is None:
            return None
        sizes = subject.spatial_shape, self.patch_size, self.patch_overlap
        self._parse_sizes(*sizes)
        return self._get_patches_locations(*sizes)

    def _generate_patches(
            self,
            subject: Subject,
            ) -> Generator[Subject, None, None]:
        """Yield one patch per grid location of the (optionally padded) subject.

        Args:
            subject: Subject from which patches are extracted.

        Raises:
            ValueError: If :meth:`_parse_sizes` rejects the sizes.
        """
        subject = self._pad(subject)
        sizes = subject.spatial_shape, self.patch_size, self.patch_overlap
        self._parse_sizes(*sizes)
        locations = self._get_patches_locations(*sizes)
        for location in locations:
            index_ini = location[:3]
            yield self.extract_patch(subject, index_ini)

    @staticmethod
    def _parse_sizes(
            image_size: TypeTripletInt,
            patch_size: TypeTripletInt,
            patch_overlap: TypeTripletInt,
            ) -> None:
        """Check that the image, patch and overlap sizes are compatible.

        Args:
            image_size: Spatial shape of the image.
            patch_size: Size of the patches.
            patch_overlap: Overlap between neighboring patches.

        Raises:
            ValueError: If the patch is larger than the image along any
                dimension, if the overlap is not smaller than the patch along
                any dimension, or if any overlap value is odd.
        """
        image_size = np.array(image_size)
        patch_size = np.array(patch_size)
        patch_overlap = np.array(patch_overlap)
        if np.any(patch_size > image_size):
            message = (
                f'Patch size {tuple(patch_size)} cannot be'
                f' larger than image size {tuple(image_size)}'
            )
            raise ValueError(message)
        if np.any(patch_overlap >= patch_size):
            message = (
                f'Patch overlap {tuple(patch_overlap)} must be smaller'
                f' than patch size {tuple(patch_size)}'
            )
            raise ValueError(message)
        # The overlap is halved to compute the padding border, hence it must
        # be even
        if np.any(patch_overlap % 2):
            message = (
                'Patch overlap must be a tuple of even integers,'
                f' not {tuple(patch_overlap)}'
            )
            raise ValueError(message)

    @staticmethod
    def _get_patches_locations(
            image_size: TypeTripletInt,
            patch_size: TypeTripletInt,
            patch_overlap: TypeTripletInt,
            ) -> np.ndarray:
        """Compute the initial and final indices of every patch in the grid.

        The sizes are expected to have been validated with
        :meth:`_parse_sizes` beforehand.

        Args:
            image_size: Spatial shape of the image.
            patch_size: Size of the patches.
            patch_overlap: Overlap between neighboring patches.

        Returns:
            Array with one row per patch and six columns: the three initial
            indices followed by the three final indices (initial index plus
            patch size). Rows are unique and sorted.
        """
        # Example with image_size 10, patch_size 5, overlap 2:
        # [0 1 2 3 4 5 6 7 8 9]
        # [0 0 0 0 0]
        #       [1 1 1 1 1]
        #           [2 2 2 2 2]
        # Locations:
        # [[0, 5],
        #  [3, 8],
        #  [5, 10]]
        indices = []
        zipped = zip(image_size, patch_size, patch_overlap)
        for im_size_dim, patch_size_dim, patch_overlap_dim in zipped:
            end = im_size_dim + 1 - patch_size_dim
            step = patch_size_dim - patch_overlap_dim
            indices_dim = list(range(0, end, step))
            # Add a last patch flush with the end of the image if the regular
            # grid does not reach it (this patch overlaps more than the rest)
            if indices_dim[-1] != im_size_dim - patch_size_dim:
                indices_dim.append(im_size_dim - patch_size_dim)
            indices.append(indices_dim)
        # All combinations of per-dimension initial indices, one per row
        indices_ini = np.array(np.meshgrid(*indices)).reshape(3, -1).T
        indices_ini = np.unique(indices_ini, axis=0)
        indices_fin = indices_ini + np.array(patch_size)
        locations = np.hstack((indices_ini, indices_fin))
        return np.array(sorted(locations.tolist()))
176