Passed
Pull Request — master (#175)
by Fernando
01:24
created

torchio.data.sampler.weighted   A

Complexity

Total Complexity 23

Size/Duplication

Total Lines 254
Duplicated Lines 5.12 %

Importance

Changes 0
Metric Value
eloc 122
dl 13
loc 254
rs 10
c 0
b 0
f 0
wmc 23

11 Methods

Rating   Name   Duplication   Size   Complexity  
A WeightedSampler.__init__() 0 9 1
A WeightedSampler.sample_probability_map() 0 39 2
A WeightedSampler.process_probability_map() 0 10 2
A WeightedSampler.__call__() 0 18 5
A WeightedSampler.get_cumulative_distribution_function() 0 26 1
A WeightedSampler.get_probability_map() 0 16 3
A WeightedSampler.get_random_index_ini() 0 8 1
A WeightedSampler.crop() 0 10 1
A WeightedSampler.clear_probability_borders() 0 36 4
A WeightedSampler.copy_and_crop() 13 13 2
A WeightedSampler.extract_patch() 0 6 1

How to fix   Duplicated Code   

Duplicated Code

Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.

Common duplication problems and their corresponding solutions are:

1
import copy
2
from typing import Union, Sequence, Generator, Optional
3
4
import numpy as np
5
6
import torch
7
8
from ...torchio import DATA
9
from ..subject import Subject
10
from .sampler import PatchSampler
11
12
13
14
class WeightedSampler(PatchSampler):
15
    r"""Randomly extract patches from a volume given a probability map.
16
17
    The probability of sampling a patch centered on a specific voxel is the
18
    value of that voxel in the probability map. The probabilities need not be
19
    normalized. For example, voxels can have values 0, 1 and 5. Voxels with
20
    value 0 will never be at the center of a patch. Voxels with value 5 will
21
    have 5 times more chance of being at the center of a patch than voxels
22
    with a value of 1.
23
24
    Args:
25
        sample: Sample generated by a
26
            :py:class:`~torchio.data.dataset.ImagesDataset`, from which image
27
            patches will be extracted.
28
        patch_size: See :py:class:`~torchio.data.PatchSampler`.
29
        probability_map: Name of the image in the sample that will be used
30
            as a probability map.
31
32
    Example:
33
        >>> import torchio
34
        >>> subject = torchio.Subject(
35
        ...     t1=torchio.Image('t1_mri.nii.gz', type=torchio.INTENSITY),
36
        ...     sampling_map=torchio.Image('sampling.nii.gz', type=torchio.SAMPLING_MAP),
37
        ... )
38
        >>> sample = torchio.ImagesDataset([subject])[0]
39
        >>> patch_size = 64
40
        >>> sampler = torchio.data.WeightedSampler(patch_size, probability_map='sampling_map')
41
        >>> for patch in sampler(sample):
42
        ...     print(patch['index_ini'])
43
44
    .. note:: The index of the center of a patch with even size :math:`s` is
45
        arbitrarily set to :math:`s/2`. This is an implementation detail that
46
        will typically not make any difference in practice.
47
48
    .. note:: Values of the probability map near the border will be set to 0 as
49
        the center of the patch cannot be at the border (unless the patch has
50
        size 1 or 2 along that axis).
51
52
    """
53
    def __init__(
            self,
            patch_size: Union[int, Sequence[int]],
            probability_map: Optional[str] = None,
            ):
        """Store the patch size and the name of the probability map image.

        Args:
            patch_size: Forwarded unchanged to the parent class initializer.
            probability_map: Key under which the probability map image is
                looked up in each sample.
        """
        super().__init__(patch_size)
        # Placeholders; this initializer does not compute them
        self.cdf = self.sort_indices = None
        self.probability_map_name = probability_map
62
63
    def __call__(self, sample, num_patches=None):
        """Yield patches cropped from ``sample`` at weighted random locations.

        This is a generator, so the checks below only run once the first
        patch is requested.

        Args:
            sample: Object providing ``check_consistent_shape()``,
                ``spatial_shape`` and the probability map image.
            num_patches: Number of patches to yield. If ``None``, patches are
                yielded indefinitely. Zero or a negative number yields no
                patches.

        Raises:
            RuntimeError: If the patch is larger than the image along any
                axis.
        """
        sample.check_consistent_shape()
        if np.any(self.patch_size > sample.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(sample.spatial_shape)}'
            )
            raise RuntimeError(message)
        probability_map = self.get_probability_map(sample)
        probability_map = self.process_probability_map(probability_map)
        # The CDF only depends on the map, so compute it once per sample
        cdf, sort_indices = self.get_cumulative_distribution_function(
            probability_map)

        if num_patches is None:
            while True:
                yield self.extract_patch(
                    sample, probability_map, cdf, sort_indices)
        else:
            # Compare against 0 instead of relying on truthiness: a negative
            # counter is truthy and would otherwise never stop the loop
            patches_left = num_patches
            while patches_left > 0:
                yield self.extract_patch(
                    sample, probability_map, cdf, sort_indices)
                patches_left -= 1
81
82
    def get_probability_map(self, sample):
83
        if self.probability_map_name in sample:
84
            data = sample[self.probability_map_name].data
85
        else:
86
            message = (
87
                f'Image "{self.probability_map_name}"'
88
                f' not found in subject sample: {sample}'
89
            )
90
            raise KeyError(message)
91
        if torch.any(data < 0):
92
            message = (
93
                'Negative values found'
94
                f' in probability map "{self.probability_map_name}"'
95
            )
96
            raise ValueError(message)
97
        return data
98
99
    def process_probability_map(self, probability_map):
100
        # Using float32 can create cdf with maximum very far from 1, e.g. 0.92!
101
        data = probability_map[0].numpy().astype(np.float64)
102
        assert data.ndim == 3
103
        if data.sum() == 0:  # although it should not be empty
104
            data += 1  # make uniform
105
        data /= data.sum()  # normalize probabilities
106
        self.clear_probability_borders(data, self.patch_size)
107
        assert data.sum() > 0
108
        return data
109
110
    @staticmethod
111
    def clear_probability_borders(probability_map, patch_size):
112
        # Set probability to 0 on voxels that wouldn't possibly be sampled given
113
        # the current patch size
114
        # We will arbitrarily define the center of an array with even length
115
        # using the // Python operator
116
        # For example, the center of an array (3, 4) will be on (1, 2)
117
        #
118
        #  . . . .        . . . .
119
        #  . . . .   ->   . . x .
120
        #  . . . .        . . . .
121
        #
122
        #  x x x x x x x      . . . . . . .
123
        #  x x x x x x x      . . x x x x .
124
        #  x x x x x x x  --> . . x x x x .
125
        #  x x x x x x x  --> . . x x x x .
126
        #  x x x x x x x      . . x x x x .
127
        #  x x x x x x x      . . . . . . .
128
        #
129
        # The dots represent removed probabilities, x mark possible locations
130
        crop_ini = patch_size // 2
131
        crop_fin = (patch_size - 1) // 2
132
        crop_i, crop_j, crop_k = crop_ini
133
        probability_map[:crop_i, :, :] = 0
134
        probability_map[:, :crop_j, :] = 0
135
        probability_map[:, :, :crop_k] = 0
136
137
        # The call tolist() is very important. Using np.uint16 as negative index
138
        # will not work because e.g. -np.uint16(2) == 65534
139
        crop_i, crop_j, crop_k = crop_fin.tolist()
140
        if crop_i:
141
            probability_map[-crop_i:, :, :] = 0
142
        if crop_j:
143
            probability_map[:, -crop_j:, :] = 0
144
        if crop_k:
145
            probability_map[:, :, -crop_k:] = 0
146
147
    @staticmethod
148
    def get_cumulative_distribution_function(probability_map):
149
        """Return the CDF of a probability map.
150
151
        The cumulative distribution function (CDF) is computed as follows:
152
153
        1. Flatten probability map
154
        2. Normalize it
155
        3. Compute sorting indices
156
        4. Sort flattened map
157
        5. Compute cumulative sum
158
159
        For example,
160
        if the probability map is [0, 0, 1, 2, 5, 1, 1, 0],
161
        the normalized version is [0.0, 0.0, 0.1, 0.2, 0.5, 0.1, 0.1, 0.0],
162
        the sorting indices are [0, 1, 7, 2, 5, 6, 3, 4],
163
        the sorted map is [0.0, 0.0, 0.0, 0.1, 0.1, 0.1, 0.2, 0.5],
164
        and the CDF is [0.0, 0.0, 0.0, 0.1, 0.2, 0.3, 0.5, 1.0].
165
        """
166
        flat_map = probability_map.flatten()
167
        flat_map_normalized = flat_map / flat_map.sum()
168
        # Get the sorting indices to that we can invert the sorting later on
169
        sort_indices = np.argsort(flat_map_normalized)
170
        flat_map_normalized_sorted = flat_map_normalized[sort_indices]
171
        cdf = np.cumsum(flat_map_normalized_sorted)
172
        return cdf, sort_indices
173
174
    def extract_patch(self, sample, probability_map, cdf, sort_indices) -> Subject:
        """Crop one randomly located patch out of ``sample``.

        Args:
            sample: Sample to crop.
            probability_map: 3D array of sampling weights.
            cdf: CDF of the probability map.
            sort_indices: Sorting indices associated with ``cdf``.

        Returns:
            Cropped copy of the sample, whose spatial shape is the patch
            size.
        """
        # TODO: replace with Crop transform
        patch = self.copy_and_crop(
            sample,
            self.get_random_index_ini(probability_map, cdf, sort_indices),
        )
        expected_shape = tuple(self.patch_size)
        assert patch.spatial_shape == expected_shape
        return patch
180
181
    def get_random_index_ini(self, probability_map, cdf, sort_indices):
182
        center = self.sample_probability_map(probability_map, cdf, sort_indices)
183
        assert np.all(center >= 0)
184
185
        # See self.clear_probability_borders
186
        index_ini = center - self.patch_size // 2
187
        assert np.all(index_ini >= 0)
188
        return index_ini
189
190
    def sample_probability_map(self, probability_map, cdf, sort_indices):
191
        """Inverse transform sampling.
192
193
        Example:
194
            >>> probability_map = np.array(
195
            ...    ((0,0,1,1,5,2,1,1,0),
196
            ...     (2,2,2,2,2,2,2,2,2)))
197
            >>> probability_map
198
            array([[0, 0, 1, 1, 5, 2, 1, 1, 0],
199
                   [2, 2, 2, 2, 2, 2, 2, 2, 2]])
200
            >>> histogram = np.zeros_like(probability_map)
201
            >>> for _ in range(100000):
202
            ...     histogram[sample_probability_map(probability_map)] += 1
203
            ...
204
            >>> histogram
205
            array([[    0,     0,  3479,  3478, 17121,  7023,  3355,  3378,     0],
206
                   [ 6808,  6804,  6942,  6809,  6946,  6988,  7002,  6826,  7041]])
207
208
        """
209
        # Get first value larger than random number
210
        random_number = torch.rand(1).item()
211
        # If probability map is float32, cdf.max() can be far from 1, e.g. 0.92
212
        if random_number > cdf.max():
213
            cdf_index = -1
214
        else:  # proceed as usual
215
            cdf_index = np.argmax(random_number < cdf)
216
217
        random_location_index = sort_indices[cdf_index]
218
        center = np.unravel_index(
219
            random_location_index,
220
            probability_map.shape
221
        )
222
223
        i, j, k = center
224
        probability = probability_map[i, j, k]
225
        assert probability > 0
226
227
        center = np.array(center).astype(int)
228
        return center
229
230 View Code Duplication
    def copy_and_crop(self, sample, index_ini: np.ndarray) -> dict:
        """Return a deep copy of ``sample`` with all its images cropped.

        Args:
            sample: Sample providing ``get_images_dict()``.
            index_ini: First index of the patch along each spatial axis.

        Returns:
            Copy of the sample in which the data of each image is the patch
            starting at ``index_ini`` with size ``self.patch_size``. The
            initial index is stored under the key ``'index_ini'``.
        """
        patch = copy.deepcopy(sample)
        index_fin = index_ini + self.patch_size
        images = sample.get_images_dict(intensity_only=False)
        for image_name, image in images.items():
            patch[image_name] = copy.deepcopy(image)
            cropped_image = patch[image_name]
            cropped_image[DATA] = self.crop(image[DATA], index_ini, index_fin)
        # torch doesn't like uint16
        patch['index_ini'] = index_ini.astype(int)
        return patch
243
244
    @staticmethod
245
    def crop(
246
            data: Union[np.ndarray, torch.Tensor],
247
            index_ini: np.ndarray,
248
            index_fin: np.ndarray,
249
            ) -> Union[np.ndarray, torch.Tensor]:
250
        i_ini, j_ini, k_ini = index_ini
251
        assert np.all(np.array(index_fin) <= np.array(data.shape[1:]))
252
        i_fin, j_fin, k_fin = index_fin
253
        return data[..., i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]
254