Passed
Pull Request — master (#175)
by Fernando
01:16
created

WeightedSampler.get_cumulative_distribution_function()   A

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 8
nop 1
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
import copy
2
from typing import Union, Sequence, Generator, Tuple, Optional
3
4
import numpy as np
5
6
import torch
7
from torch.utils.data import IterableDataset
8
9
from ...torchio import DATA
10
from ...utils import to_tuple
11
from ..subject import Subject
12
13
14
15
class WeightedSampler(IterableDataset):
    r"""Extract random patches from a volume given a probability map.

    The probability of sampling a patch centered on a specific voxel is the
    value of that voxel in the probability map. The probabilities need not be
    normalized. For example, voxels can have values 0, 1 and 5. Voxels with
    value 0 will never be at the center of a patch. Voxels with value 5 will
    have 5 times more chance of being at the center of a patch than voxels
    with a value of 1.

    Args:
        sample: Sample generated by a
            :py:class:`~torchio.data.dataset.ImagesDataset`, from which image
            patches will be extracted.
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided, :math:`d = h = w = n`.
        probability_map: Name of the image in the sample that will be used
            as a probability map. If ``None``, all voxels have the same
            probability of being at the patch center.

    Raises:
        ValueError: If the patch size is not positive, if it is larger than
            the image, if the probability map contains negative values or
            if no voxel with positive probability can be a patch center.

    .. note:: The index of the center of a patch with even size :math:`s` is
        arbitrarily set to :math:`s/2`. This is an implementation detail that
        will typically not make any difference in practice.

    .. note:: Values of the probability map near the border will be set to 0 as
        the center of the patch cannot be at the border (unless the patch has
        size 1 or 2 along that axis).

    """
    def __init__(
            self,
            sample: 'Subject',
            patch_size: Union[int, Sequence[int]],
            probability_map: Optional[str] = None,
            ):
        super().__init__()
        sample.check_consistent_shape()
        self.sample = sample
        patch_size = np.array(to_tuple(patch_size, length=3))
        if np.any(patch_size < 1):
            message = (
                'Patch size must be positive,'
                f' but {tuple(patch_size.tolist())} was given'
            )
            raise ValueError(message)
        self.patch_size = patch_size.astype(np.uint16)
        if np.any(self.patch_size > sample.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(sample.spatial_shape)}'
            )
            raise ValueError(message)
        self.probability_map = self.process_probability_map(probability_map)
        self.cdf, self.sort_indices = self.get_cumulative_distribution_function(
            self.probability_map)

    def __iter__(self) -> Generator['Subject', None, None]:
        """Yield randomly located patches forever (the stream never ends)."""
        while True:
            yield self.extract_patch()

    def process_probability_map(self, probability_map_name):
        """Build the normalized 3D sampling distribution of patch centers.

        Args:
            probability_map_name: Key of the image in ``self.sample`` to use
                as probability map. If the key is not in the sample (e.g.
                ``None``), a uniform map is used.

        Returns:
            A 3D ``float64`` array that sums to 1 and is 0 on voxels that
            cannot be the center of a patch of size ``self.patch_size``.

        Raises:
            ValueError: If the map has negative values or if all its positive
                values are too close to the border to be patch centers.
        """
        if probability_map_name in self.sample:
            data = self.sample[probability_map_name].data
        else:
            data = torch.ones(self.sample.shape)
        # astype() returns a new array, so the data stored in the sample is
        # never modified below. float64 keeps the cumulative sum accurate.
        data = data[0].numpy().astype(np.float64)
        assert data.ndim == 3
        if np.any(data < 0):
            message = (
                'Negative values found'
                f' in probability map "{probability_map_name}"'
            )
            raise ValueError(message)
        if data.sum() == 0:  # although it should not be empty
            data += 1  # make uniform
        # Clear the borders before normalizing so that the result sums to 1
        self.clear_probability_borders(data, self.patch_size)
        total = data.sum()
        if total == 0:
            message = (
                'All positive values'
                f' of probability map "{probability_map_name}"'
                ' are too close to the border to be the center of a patch'
            )
            raise ValueError(message)
        data /= total  # normalize probabilities
        return data

    @staticmethod
    def clear_probability_borders(probability_map, patch_size):
        """Zero, in place, voxels that cannot be the center of a patch.

        Args:
            probability_map: Array with one axis per element of
                ``patch_size``. It is modified in place.
            patch_size: Patch size along each axis of ``probability_map``.
        """
        # Set probability to 0 on voxels that wouldn't possibly be sampled given
        # the current patch size
        # We will arbitrarily define the center of an array with even length
        # using the // Python operator
        # For example, the center of an array (3, 4) will be on (1, 2)
        #
        #  . . . .        . . . .
        #  . . . .   ->   . . x .
        #  . . . .        . . . .
        #
        #  x x x x x x x      . . . . . . .
        #  x x x x x x x      . . x x x x .
        #  x x x x x x x  --> . . x x x x .
        #  x x x x x x x  --> . . x x x x .
        #  x x x x x x x      . . x x x x .
        #  x x x x x x x      . . . . . . .
        #
        # The dots represent removed probabilities, x mark possible locations

        # Use Python integers: negating an unsigned NumPy integer wraps around
        # (-np.uint16(2) == 65534), which would make the trailing slices empty
        sizes = np.asarray(patch_size).astype(int).tolist()
        for axis, size in enumerate(sizes):
            # A patch centered on c spans [c - size // 2, c - size // 2 + size)
            # so size // 2 voxels are needed before the center and
            # (size - 1) // 2 voxels after it
            leading = size // 2
            trailing = (size - 1) // 2
            index = [slice(None)] * probability_map.ndim
            if leading > 0:
                index[axis] = slice(None, leading)
                probability_map[tuple(index)] = 0
            if trailing > 0:
                index[axis] = slice(-trailing, None)
                probability_map[tuple(index)] = 0

    def get_random_index_ini(self):
        """Sample a patch center and return the first index of the patch."""
        center = self.sample_probability_map()

        # See self.clear_probability_borders
        index_ini = center - self.patch_size // 2
        assert np.all(index_ini >= 0)
        return index_ini

    @staticmethod
    def get_cumulative_distribution_function(probability_map):
        """Compute the CDF of the voxels sorted by increasing probability.

        Args:
            probability_map: Array of non-negative values. It does not need
                to be normalized.

        Returns:
            Tuple ``(cdf, sort_indices)``. ``sort_indices`` are the indices
            that sort the flattened map and ``cdf`` is the cumulative sum of
            the sorted, normalized probabilities, so ``cdf[-1]`` is 1 (up to
            rounding) and ``cdf[i]`` corresponds to the voxel with flat index
            ``sort_indices[i]``.

        Raises:
            ValueError: If the map does not have any positive value.
        """
        flat_map = probability_map.flatten()
        total = flat_map.sum()
        if not total > 0:
            message = 'The probability map must have at least one positive value'
            raise ValueError(message)
        flat_map_normalized = flat_map / total
        # Get the sorting indices so that we can invert the sorting later on
        sort_indices = np.argsort(flat_map_normalized)
        # The normalized values must be accumulated, otherwise the last value
        # of the CDF is the sum of the map and not 1
        flat_map_normalized_sorted = flat_map_normalized[sort_indices]
        cdf = np.cumsum(flat_map_normalized_sorted)
        return cdf, sort_indices

    def sample_probability_map(self):
        """Sample the index of a patch center using inverse transform sampling.

        Reads ``self.cdf``, ``self.sort_indices`` and the shape of
        ``self.probability_map``.

        Returns:
            Integer array with the index of the sampled voxel.

        Example (illustrative, 2D for brevity): sampling 100000 locations
        from the map::

            [[0, 0, 1, 1, 5, 2, 1, 1, 0],
             [2, 2, 2, 2, 2, 2, 2, 2, 2]]

        gives a histogram of sampled locations similar to::

            [[    0,     0,  3479,  3478, 17121,  7023,  3355,  3378,     0],
             [ 6808,  6804,  6942,  6809,  6946,  6988,  7002,  6826,  7041]]

        """
        random_number = torch.rand(1).item()
        # Get first value larger than random number. The CDF is sorted, so a
        # binary search can be used
        cdf_index = int(np.searchsorted(self.cdf, random_number, side='right'))
        # The last value of the CDF might be slightly smaller than 1 because of
        # rounding errors. In that case, use the last (most probable) voxel
        cdf_index = min(cdf_index, len(self.cdf) - 1)

        random_location_index = self.sort_indices[cdf_index]
        center = np.unravel_index(
            random_location_index,
            self.probability_map.shape
        )
        center = np.array(center).astype(int)
        return center

    def extract_patch(self) -> 'Subject':
        """Return a copy of the sample cropped around a random center."""
        # TODO: replace with Crop transform
        index_ini = self.get_random_index_ini()
        cropped_sample = self.copy_and_crop(index_ini)
        return cropped_sample

    def copy_and_crop(self, index_ini: np.ndarray) -> 'Subject':
        """Return a deep copy of the sample with all its images cropped.

        Args:
            index_ini: First index of the patch along each spatial axis.

        Returns:
            Copy of ``self.sample`` in which the data of every image is
            cropped to ``self.patch_size``, with ``index_ini`` stored under
            the key ``'index_ini'``.
        """
        index_fin = index_ini + self.patch_size
        cropped_sample = copy.deepcopy(self.sample)
        iterable = self.sample.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            cropped_sample[image_name] = copy.deepcopy(image)
            sample_image_dict = image
            cropped_image_dict = cropped_sample[image_name]
            # NOTE(review): slicing typically returns a view, so the cropped
            # data presumably shares memory with the original image - confirm
            cropped_image_dict[DATA] = self.crop(
                sample_image_dict[DATA], index_ini, index_fin)
        # torch doesn't like uint16
        cropped_sample['index_ini'] = index_ini.astype(int)
        return cropped_sample

    @staticmethod
    def crop(
            data: Union[np.ndarray, torch.Tensor],
            index_ini: np.ndarray,
            index_fin: np.ndarray,
            ) -> Union[np.ndarray, torch.Tensor]:
        """Slice the last three axes of ``data`` to ``[index_ini, index_fin)``.

        Args:
            data: Array or tensor with at least three axes.
            index_ini: First index (inclusive) along each of the last three
                axes.
            index_fin: Last index (exclusive) along each of the last three
                axes.

        Returns:
            The slice of ``data``; leading axes are kept untouched.
        """
        i_ini, j_ini, k_ini = index_ini
        i_fin, j_fin, k_fin = index_fin
        return data[..., i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]
206