Passed
Push — master (7b848f...cac223) by Fernando
created 02:39

torchio.data.inference.grid_sampler — rating A

Complexity

Total Complexity 13

Size/Duplication

Total Lines 171
Duplicated Lines 9.94%

Importance

Changes 0
Metric   Value
eloc (effective lines of code)   120
dl (duplicated lines)   17
loc (lines of code)   171
rs   10
c   0
b   0
f   0
wmc (weighted method count)   13

7 Methods

Rating   Name   Duplication   Size   Complexity  
A GridSampler.__init__() 0 13 1
A GridSampler.copy_and_crop() 17 17 2
A GridSampler.__getitem__() 0 8 1
A GridSampler.extract_patch() 0 12 1
A GridSampler.__len__() 0 2 1
A GridSampler._grid_spatial_coordinates() 0 39 3
A GridSampler._enumerate_step_points() 0 28 3

1 Function

Rating   Name   Duplication   Size   Complexity  
A crop() 0 8 1

How to fix: Duplicated Code

Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.

The most common duplication problem is a copy-pasted block that drifts out of sync across its copies; the usual solution is to extract the shared code into a single function or class, as in the sketch below.
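For instance, a minimal sketch of that extraction (the class and helper names here are hypothetical, not this project's actual layout): two samplers that each carried an identical copy-and-crop block end up sharing one module-level helper.

from torch.utils.data import Dataset


def copy_and_crop_images(images: dict, index_ini, index_fin) -> dict:
    # The single shared definition of the previously duplicated block
    i, j, k = index_ini
    i2, j2, k2 = index_fin
    return {
        name: tensor[..., i:i2, j:j2, k:k2]
        for name, tensor in images.items()
    }


class SamplerA(Dataset):
    def extract_patch(self, images, index_ini, index_fin):
        return copy_and_crop_images(images, index_ini, index_fin)


class SamplerB(Dataset):
    # Second former copy, now a one-line call to the shared helper
    def extract_patch(self, images, index_ini, index_fin):
        return copy_and_crop_images(images, index_ini, index_fin)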

import copy
from typing import Tuple

import torch
import numpy as np
from torch.utils.data import Dataset

from ...utils import to_tuple
from ...torchio import LOCATION, TypeTuple, DATA
from ..subject import Subject


class GridSampler(Dataset):
    r"""Extract patches across a whole volume.

    Grid samplers are useful to perform inference using all patches from a
    volume. They are often used with a
    :py:class:`~torchio.data.GridAggregator`.

    Args:
        sample: Instance of :py:class:`~torchio.data.subject.Subject`
            from which patches will be extracted.
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided,
            :math:`d = h = w = n`.
        patch_overlap: Tuple of integers :math:`(d_o, h_o, w_o)` specifying the
            overlap between patches for dense inference. If a single number
            :math:`n` is provided, :math:`d_o = h_o = w_o = n`.

    .. note:: Adapted from NiftyNet. See `this NiftyNet tutorial
        <https://niftynet.readthedocs.io/en/dev/window_sizes.html>`_ for more
        information.
    """
    def __init__(
            self,
            sample: Subject,
            patch_size: TypeTuple,
            patch_overlap: TypeTuple,
            ):
        self.sample = sample
        patch_size = to_tuple(patch_size, length=3)
        patch_overlap = to_tuple(patch_overlap, length=3)
        self.locations = self._grid_spatial_coordinates(
            self.sample.shape,
            patch_size,
            patch_overlap,
        )

    def __len__(self):
        return len(self.locations)

    def __getitem__(self, index):
        # Assume 3D
        # Each location is (i_ini, j_ini, k_ini, i_fin, j_fin, k_fin)
        location = self.locations[index]
        index_ini = location[:3]
        index_fin = location[3:]
        cropped_sample = self.extract_patch(self.sample, index_ini, index_fin)
        cropped_sample[LOCATION] = location
        return cropped_sample

    def extract_patch(
            self,
            sample: Subject,
            index_ini: Tuple[int, int, int],
            index_fin: Tuple[int, int, int],
            ) -> Subject:
        cropped_sample = self.copy_and_crop(
            sample,
            index_ini,
            index_fin,
        )
        return cropped_sample

    @staticmethod
    def copy_and_crop(
            sample: Subject,
            index_ini: np.ndarray,
            index_fin: np.ndarray,
            ) -> dict:
        cropped_sample = {}
        iterable = sample.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            # Deep-copy the image so the original sample is left untouched,
            # then replace its data with the cropped patch
            cropped_sample[image_name] = copy.deepcopy(image)
            sample_image_dict = image
            cropped_image_dict = cropped_sample[image_name]
            cropped_image_dict[DATA] = crop(
                sample_image_dict[DATA], index_ini, index_fin)
        # torch doesn't like uint16
        cropped_sample['index_ini'] = index_ini.astype(int)
        return cropped_sample
    @staticmethod
    def _grid_spatial_coordinates(
            volume_shape: Tuple[int, int, int],
            patch_shape: Tuple[int, int, int],
            border: Tuple[int, int, int],
            ) -> np.ndarray:
        volume_shape = np.array(volume_shape)
        patch_shape = np.array(patch_shape)
        border = np.array(border)
        # Step between patch origins: the patch size minus the
        # overlap on both sides
        grid_size = np.maximum(patch_shape - 2 * border, 0)
        num_dims = len(volume_shape)

        steps_along_each_dim = [
            GridSampler._enumerate_step_points(
                starting=0,
                ending=volume_shape[i],
                patch_shape=patch_shape[i],
                step_size=grid_size[i],
            )
            for i in range(num_dims)
        ]
        # Cartesian product of the per-axis start points
        starting_coords = np.asanyarray(np.meshgrid(*steps_along_each_dim))
        starting_coords = starting_coords.reshape((num_dims, -1)).T
        n_locations = starting_coords.shape[0]
        # Each row stores the start indices followed by the end indices
        spatial_coords = np.zeros((n_locations, num_dims * 2), dtype=np.int32)
        spatial_coords[:, :num_dims] = starting_coords
        for idx in range(num_dims):
            spatial_coords[:, num_dims + idx] = (
                starting_coords[:, idx]
                + patch_shape[idx]
            )
        max_coordinates = np.max(spatial_coords, axis=0)[num_dims:]
        if np.any(max_coordinates > volume_shape[:num_dims]):
            message = (
                f'Window size {tuple(patch_shape)}'
                f' is larger than volume {tuple(volume_shape)}'
            )
            raise ValueError(message)
        return spatial_coords
    @staticmethod
    def _enumerate_step_points(
            starting: int,
            ending: int,
            patch_shape: int,
            step_size: int,
            ) -> np.ndarray:
        # Operates on a single spatial axis, so the arguments are scalars
        starting = np.maximum(starting, 0)
        ending = np.maximum(ending, 0)
        patch_shape = np.maximum(patch_shape, 1)
        step_size = np.maximum(step_size, 1)

        starting = np.minimum(starting, ending)
        ending = np.maximum(starting, ending)

        # March a window of length patch_shape along the axis
        sampling_point_set = []
        while (starting + patch_shape) <= ending:
            sampling_point_set.append(starting)
            starting = starting + step_size
        # Add a final window flush with the end of the axis
        additional_last_point = ending - patch_shape
        sampling_point_set.append(np.maximum(additional_last_point, 0))
        sampling_point_set = np.unique(sampling_point_set, axis=0)
        if len(sampling_point_set) == 2:
            # Only two windows: add one midway between them
            mean = np.round(np.mean(sampling_point_set, axis=0))
            sampling_point_set = np.append(sampling_point_set, mean)
        _, uniq_idx = np.unique(sampling_point_set, return_index=True)
        return sampling_point_set[np.sort(uniq_idx)]


def crop(
        image: torch.Tensor,
        index_ini: np.ndarray,
        index_fin: np.ndarray,
        ) -> torch.Tensor:
    # Slice the last three (spatial) dimensions of the image tensor
    i_ini, j_ini, k_ini = index_ini
    i_fin, j_fin, k_fin = index_fin
    return image[..., i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]
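
As a quick illustration of how the grid is laid out, the two static helpers above can be called directly; the toy sizes below are chosen only for this sketch.

import numpy as np

# One axis of length 10, windows of length 4, step 4 - 2 * 1 = 2
starts = GridSampler._enumerate_step_points(
    starting=0, ending=10, patch_shape=4, step_size=2)
print(starts)  # [0 2 4 6]: regular strides plus a final window
               # flush with the end of the axis

# Full 3D grid for a 10x10x10 volume, 4x4x4 patches, border 1
locations = GridSampler._grid_spatial_coordinates(
    volume_shape=(10, 10, 10),
    patch_shape=(4, 4, 4),
    border=(1, 1, 1),
)
print(locations.shape)  # (64, 6): 4 start points per axis, 4**3 patches,
                        # each row (i_ini, j_ini, k_ini, i_fin, j_fin, k_fin)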
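The dense-inference loop the docstring describes would then look roughly as follows. This is a sketch: sample (a Subject assumed to hold a 't1' image), model (any 3D network) and the patch sizes are placeholders, and the aggregation step is only indicated, since GridAggregator's API is defined elsewhere.

import torch
from torch.utils.data import DataLoader

# `sample`, `model` and the 't1' image name are assumed for illustration
sampler = GridSampler(sample, patch_size=64, patch_overlap=4)
loader = DataLoader(sampler, batch_size=2)

with torch.no_grad():
    for patches in loader:
        inputs = patches['t1'][DATA]   # (B, C, 64, 64, 64) patch tensors
        locations = patches[LOCATION]  # (B, 6) per-patch start/end indices
        outputs = model(inputs)
        # hand (outputs, locations) to a torchio.data.GridAggregator
        # to stitch the dense prediction back into the full volume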