Passed
Pull Request — master (#182)
by Fernando
01:13
created

GridSampler.parse_sizes()   A

Complexity

Conditions 3

Size

Total Lines 21
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 16
nop 3
dl 0
loc 21
rs 9.6
c 0
b 0
f 0
1
import copy
2
3
import torch
4
import numpy as np
5
from torch.utils.data import Dataset
6
7
from ..sampler.sampler import PatchSampler
8
from ...utils import to_tuple
9
from ...torchio import LOCATION, TypeTuple, DATA, TypeTripletInt
10
from ..subject import Subject
11
12
13
class GridSampler(PatchSampler, Dataset):
    r"""Extract patches across a whole volume.

    Grid samplers are useful to perform inference using all patches from a
    volume. It is often used with a :py:class:`~torchio.data.GridAggregator`.

    Args:
        sample: Instance of :py:class:`~torchio.data.subject.Subject`
            from which patches will be extracted.
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided,
            :math:`d = h = w = n`.
        patch_overlap: Tuple of integers :math:`(d_o, h_o, w_o)` specifying the
            overlap between patches for dense inference. If a single number
            :math:`n` is provided, :math:`d_o = h_o = w_o = n`.

    Raises:
        ValueError: If the patch size is larger than the spatial shape of
            ``sample`` along any axis, or if the patch overlap is not smaller
            than the patch size along every axis.

    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
        information.
    """
    def __init__(
            self,
            sample: Subject,
            patch_size: TypeTuple,
            patch_overlap: TypeTuple,
            ):
        self.sample = sample
        PatchSampler.__init__(self, patch_size)
        patch_size = to_tuple(patch_size, length=3)
        patch_overlap = to_tuple(patch_overlap, length=3)
        sizes = self.sample.spatial_shape, patch_size, patch_overlap
        # Validate before computing locations so bad sizes fail early.
        self.parse_sizes(*sizes)
        self.locations = self.get_patches_locations(*sizes)

    def __len__(self):
        """Return the number of patch locations on the grid."""
        return len(self.locations)

    def __getitem__(self, index):
        """Return the patch at ``index`` with its location attached.

        The location row holds the three start indices followed by the three
        end indices, and is stored in the returned sample under ``LOCATION``.
        """
        # Assume 3D
        location = self.locations[index]
        index_ini = location[:3]
        index_fin = location[3:]
        cropped_sample = self.extract_patch(self.sample, index_ini, index_fin)
        cropped_sample[LOCATION] = location
        return cropped_sample

    @staticmethod
    def parse_sizes(
            image_size: TypeTripletInt,
            patch_size: TypeTripletInt,
            patch_overlap: TypeTripletInt,
            ) -> None:
        """Check that the patch size and overlap are compatible.

        Raises:
            ValueError: If ``patch_size`` exceeds ``image_size`` along any
                axis, or if ``patch_overlap`` is not strictly smaller than
                ``patch_size`` along every axis (a step of zero or less
                would make the grid impossible to build).
        """
        image_size = np.array(image_size)
        patch_size = np.array(patch_size)
        patch_overlap = np.array(patch_overlap)
        if np.any(patch_size > image_size):
            message = (
                f'Patch size {tuple(patch_size)} cannot be'
                f' larger than image size {tuple(image_size)}'
            )
            raise ValueError(message)
        if np.any(patch_overlap >= patch_size):
            message = (
                f'Patch overlap {tuple(patch_overlap)} must be smaller'
                f' than patch size {tuple(patch_size)}'
            )
            raise ValueError(message)

    def extract_patch(
            self,
            sample: Subject,
            index_ini: TypeTripletInt,
            index_fin: TypeTripletInt,
            ) -> Subject:
        """Crop ``sample`` to the box from ``index_ini`` up to ``index_fin``.

        The crop transform receives the start indices and the box size
        (``index_fin - index_ini``), so ``index_fin`` acts as an exclusive end.
        """
        crop = self.get_crop_transform(
            sample.spatial_shape,
            index_ini,
            index_fin - index_ini,
        )
        cropped_sample = crop(sample)
        return cropped_sample

    @staticmethod
    def get_patches_locations(
            image_size: TypeTripletInt,
            patch_size: TypeTripletInt,
            patch_overlap: TypeTripletInt,
            ) -> np.ndarray:
        """Compute the start/end indices of every patch on the grid.

        Along each axis, starts are spaced ``patch_size - patch_overlap``
        apart beginning at 0. The last valid start,
        ``image_size - patch_size``, is always included so that the final
        patch reaches the end of the volume.

        Returns:
            Integer array with one row per patch and six columns: the three
            start indices followed by the three end indices
            (``start + patch_size``). Rows are sorted lexicographically.
        """
        indices = []
        zipped = zip(image_size, patch_size, patch_overlap)
        for im_size_dim, patch_size_dim, patch_overlap_dim in zipped:
            last_ini = im_size_dim - patch_size_dim
            step = patch_size_dim - patch_overlap_dim
            indices_dim = list(range(0, last_ini + 1, step))
            # The stride does not necessarily land on the last valid start.
            # Add that start explicitly so no voxels at the end of the axis
            # are left uncovered. For example, with image 12, patch 5 and
            # overlap 1 the stride gives [0, 4], so 7 must be appended.
            if not indices_dim or indices_dim[-1] != last_ini:
                indices_dim.append(last_ini)
            indices.append(indices_dim)
        # Cartesian product of the per-axis starts (3D assumed by reshape).
        indices_ini = np.array(np.meshgrid(*indices)).reshape(3, -1).T
        indices_ini = np.unique(indices_ini, axis=0)
        indices_fin = indices_ini + np.array(patch_size)
        locations = np.hstack((indices_ini, indices_fin))
        return np.array(sorted(locations.tolist()))
116