Passed
Pull Request — master (#182)
created by Fernando, 01:14

GridSampler.get_patches_locations() (rated A)

Complexity:  Conditions 3
Size:        Total Lines 20, Code Lines 19
Duplication: Lines 0, Ratio 0 %
Importance:  Changes 0

Metric  Value
cc      3
eloc    19
nop     3
dl      0
loc     20
rs      9.45
c       0
b       0
f       0
import copy
from typing import Tuple

import torch
import numpy as np
from torch.utils.data import Dataset

from ...utils import to_tuple
from ...torchio import LOCATION, TypeTuple, DATA, TypeTripletInt
from ..subject import Subject


class GridSampler(Dataset):
    r"""Extract patches across a whole volume.

    Grid samplers are useful for performing inference using all patches in a
    volume. They are often used with a
    :py:class:`~torchio.data.GridAggregator`.

    Args:
        sample: Instance of :py:class:`~torchio.data.subject.Subject`
            from which patches will be extracted.
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided,
            :math:`d = h = w = n`.
        patch_overlap: Tuple of integers :math:`(d_o, h_o, w_o)` specifying
            the overlap between patches for dense inference. If a single
            number :math:`n` is provided, :math:`d_o = h_o = w_o = n`.

    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
        information.
    """
    def __init__(
            self,
            sample: Subject,
            patch_size: TypeTuple,
            patch_overlap: TypeTuple,
            ):
        self.sample = sample
        patch_size = to_tuple(patch_size, length=3)
        patch_overlap = to_tuple(patch_overlap, length=3)
        sizes = self.sample.spatial_shape, patch_size, patch_overlap
        self.parse_sizes(*sizes)
        self.locations = self.get_patches_locations(*sizes)

    def __len__(self):
        return len(self.locations)

    def __getitem__(self, index):
        # Assume 3D
        location = self.locations[index]
        index_ini = location[:3]
        index_fin = location[3:]
        cropped_sample = self.extract_patch(self.sample, index_ini, index_fin)
        cropped_sample[LOCATION] = location
        return cropped_sample

    @staticmethod
    def parse_sizes(
            image_size: TypeTripletInt,
            patch_size: TypeTripletInt,
            patch_overlap: TypeTripletInt,
            ) -> None:
        image_size = np.array(image_size)
        patch_size = np.array(patch_size)
        patch_overlap = np.array(patch_overlap)
        if np.any(patch_size > image_size):
            message = (
                f'Patch size {tuple(patch_size)} cannot be'
                f' larger than image size {tuple(image_size)}'
            )
            raise ValueError(message)
        if np.any(patch_overlap >= patch_size):
            message = (
                f'Patch overlap {tuple(patch_overlap)} must be smaller'
                f' than patch size {tuple(patch_size)}'
            )
            raise ValueError(message)

    def extract_patch(
            self,
            sample: Subject,
            index_ini: TypeTripletInt,
            index_fin: TypeTripletInt,
            ) -> Subject:
        cropped_sample = self.copy_and_crop(
            sample,
            index_ini,
            index_fin,
        )
        return cropped_sample

    @staticmethod
    def copy_and_crop(
            sample: Subject,
            index_ini: np.ndarray,
            index_fin: np.ndarray,
            ) -> dict:
        cropped_sample = {}
        iterable = sample.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            cropped_sample[image_name] = copy.deepcopy(image)
            sample_image_dict = image
            cropped_image_dict = cropped_sample[image_name]
            cropped_image_dict[DATA] = crop(
                sample_image_dict[DATA], index_ini, index_fin)
        # torch doesn't like uint16
        cropped_sample['index_ini'] = index_ini.astype(int)
        return cropped_sample

    @staticmethod
    def get_patches_locations(
            image_size: TypeTripletInt,
            patch_size: TypeTripletInt,
            patch_overlap: TypeTripletInt,
            ) -> np.ndarray:
        indices = []
        zipped = zip(image_size, patch_size, patch_overlap)
        for im_size_dim, patch_size_dim, patch_overlap_dim in zipped:
            # Regular grid of patch starts, spaced by (patch size - overlap)
            end = im_size_dim + 1 - patch_size_dim
            step = patch_size_dim - patch_overlap_dim
            indices_dim = list(range(0, end, step))
            # Add a final patch flush with the image border if the regular
            # grid does not reach the end of this dimension
            if indices_dim[-1] + patch_size_dim < im_size_dim:
                indices_dim.append(im_size_dim - patch_size_dim)
            indices.append(indices_dim)
        indices_ini = np.array(np.meshgrid(*indices)).reshape(3, -1).T
        indices_ini = np.unique(indices_ini, axis=0)
        indices_fin = indices_ini + np.array(patch_size)
        locations = np.hstack((indices_ini, indices_fin))
        return np.array(sorted(locations.tolist()))
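
    # A quick worked example of the grid arithmetic above (a sketch with
    # made-up sizes, not part of the original file): for a 10 x 10 x 10
    # volume, 4 x 4 x 4 patches and overlap 2, starts are spaced by
    # step = 4 - 2 = 2, giving starts (0, 2, 4, 6) per axis and
    # 4 ** 3 = 64 patches in total:
    #
    #     >>> locations = GridSampler.get_patches_locations(
    #     ...     (10, 10, 10), (4, 4, 4), (2, 2, 2))
    #     >>> locations.shape
    #     (64, 6)
    #     >>> locations[0]    # (i_ini, j_ini, k_ini, i_fin, j_fin, k_fin)
    #     array([0, 0, 0, 4, 4, 4])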


def crop(
        image: torch.Tensor,
        index_ini: np.ndarray,
        index_fin: np.ndarray,
        ) -> torch.Tensor:
    # Crop only the three trailing (spatial) dimensions,
    # preserving any leading channel dimensions
    i_ini, j_ini, k_ini = index_ini
    i_fin, j_fin, k_fin = index_fin
    return image[..., i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]
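
As the docstring notes, a grid sampler is typically paired with a
GridAggregator for dense inference. Below is a minimal sketch of that
pattern; the Subject named sample, the image name 't1', and the sizes are
all placeholders, and the model step is left abstract:

sampler = GridSampler(sample, patch_size=64, patch_overlap=4)
loader = torch.utils.data.DataLoader(sampler, batch_size=4)
for batch in loader:
    patches = batch['t1'][DATA]    # batched patch tensors, e.g. (4, 1, 64, 64, 64)
    locations = batch[LOCATION]    # (4, 6) start/stop indices of each patch
    # Run the model on the patches, then write the outputs back into a
    # full-size volume at the given locations, e.g. with
    # torchio.data.GridAggregator

Because __getitem__ returns a plain dictionary, the default DataLoader
collation batches every key, including the LOCATION entry that an
aggregator needs in order to place each prediction.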