Passed
Pull Request — master (#175)
by Fernando
01:14
created

WeightedSampler.sample_probability_map()   A

Complexity

Conditions 2

Size

Total Lines 34
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 11
nop 1
dl 0
loc 34
rs 9.85
c 0
b 0
f 0
1
import copy
2
from typing import Union, Sequence, Generator, Tuple, Optional
3
4
import numpy as np
5
6
import torch
7
from torch.utils.data import IterableDataset
8
9
from ...torchio import DATA
10
from ...utils import to_tuple
11
from ..subject import Subject
12
13
14
15
class WeightedSampler(IterableDataset):
    r"""Extract random patches from a volume.

    Patch centers are drawn with probability proportional to the values of a
    probability map, restricted to the voxels around which a whole patch fits
    inside the volume.

    Args:
        sample: Sample generated by a
            :py:class:`~torchio.data.dataset.ImagesDataset`, from which image
            patches will be extracted.
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided, :math:`d = h = w = n`.
        probability_map_name: Name of the image in the sample that will be used
            as a probability map. If ``None``, all valid patch centers are
            equally likely.

    Raises:
        ValueError: If a patch size is not positive or is larger than the
            image, if ``probability_map_name`` is not an image of the sample,
            or if the probability map contains negative values.
    """
    def __init__(
            self,
            sample: 'Subject',
            patch_size: Union[int, Sequence[int]],
            probability_map_name: Optional[str] = None,
            ):
        sample.check_consistent_shape()
        self.sample = sample
        patch_size = to_tuple(patch_size, length=3)
        # Check before the uint16 cast, which cannot represent negative values
        if any(n < 1 for n in patch_size):
            message = f'Patch size values must be positive, not {patch_size}'
            raise ValueError(message)
        self.patch_size = np.array(patch_size, dtype=np.uint16)
        if np.any(self.patch_size > sample.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(sample.spatial_shape)}'
            )
            raise ValueError(message)
        self.probability_map = self.process_probability_map(
            probability_map_name)
        self.cdf, self.sort_indices = self.get_cumulative_distribution_function(
            self.probability_map)

    def __iter__(self) -> Generator['Subject', None, None]:
        """Yield randomly located patches indefinitely."""
        while True:
            yield self.extract_patch()

    def process_probability_map(self, probability_map_name):
        """Return the normalized map used to sample patch centers.

        Args:
            probability_map_name: Name of an image in the sample whose first
                channel is used as an (unnormalized) probability map, or
                ``None`` to give the same weight to every voxel.

        Returns:
            A float64 NumPy array with the spatial shape of the sample. It is
            zero on voxels that cannot be a patch center and sums to 1.

        Raises:
            ValueError: If the name is not an image of the sample, or if the
                map contains negative values.
        """
        if probability_map_name is None:
            data = torch.ones(self.sample.shape)
        elif probability_map_name in self.sample:
            data = self.sample[probability_map_name].data
        else:
            message = (
                f'Probability map "{probability_map_name}"'
                ' not found in sample'
            )
            raise ValueError(message)
        # np.array always copies, so the in-place operations below never modify
        # the sample's own data. float64 keeps the cumulative sum accurate
        # over many voxels.
        data = np.array(data[0], dtype=np.float64)
        if data.ndim != 3:
            message = (
                f'Probability map must be 3D, but has shape {data.shape}'
            )
            raise ValueError(message)
        if np.any(data < 0):
            message = (
                'Negative values found'
                f' in probability map "{probability_map_name}"'
            )
            raise ValueError(message)
        self.clear_probability_borders(data, self.patch_size)
        if data.sum() == 0:  # although it should not be empty
            # No valid center has nonzero probability: sample uniformly among
            # the locations where a whole patch fits
            data.fill(1)
            self.clear_probability_borders(data, self.patch_size)
        # Normalize only after clearing the borders so that the map sums to 1
        data /= data.sum()
        return data

    @staticmethod
    def clear_probability_borders(probability_map, patch_size):
        """Zero, in place, the voxels that cannot be the center of a patch.

        Args:
            probability_map: Array modified in place. It has one axis per
                element of ``patch_size``.
            patch_size: Patch size along each axis.
        """
        # Set probability to 0 on voxels that wouldn't possibly be sampled given
        # the current patch size
        # We will arbitrarily define the center of an array with even length
        # using the // Python operator
        # For example, the center of an array (3, 4) will be on (1, 2)
        #
        #  . . . .        . . . .
        #  . . . .   ->   . . x .
        #  . . . .        . . . .
        #
        #  x x x x x x x      . . . . . . .
        #  x x x x x x x      . . x x x x .
        #  x x x x x x x  --> . . x x x x .
        #  x x x x x x x  --> . . x x x x .
        #  x x x x x x x      . . x x x x .
        #  x x x x x x x      . . . . . . .
        #
        # The dots represent removed probabilities, x mark possible locations
        # (here for a patch of size (3, 4))
        #
        # Along an axis of length L, a patch of size p centered at c spans
        # [c - p // 2, c - p // 2 + p). It fits in the volume only if
        # p // 2 <= c <= L - 1 - (p - 1) // 2, so the first p // 2 and the
        # last (p - 1) // 2 voxels are cleared.
        for axis, size in enumerate(patch_size):
            # Use Python ints: negating an unsigned NumPy scalar (patch_size is
            # uint16 in this class) can wrap around instead of going negative
            size = int(size)
            head = size // 2
            tail = (size - 1) // 2
            index = [slice(None)] * probability_map.ndim
            index[axis] = slice(None, head)
            probability_map[tuple(index)] = 0
            if tail > 0:  # slice(-0, None) would select the whole axis
                index[axis] = slice(-tail, None)
                probability_map[tuple(index)] = 0

    def get_random_index_ini(self):
        """Return the index of the first voxel of a randomly located patch."""
        center = self.sample_probability_map()

        # See self.clear_probability_borders
        index_ini = center - self.patch_size // 2
        assert np.all(index_ini >= 0)
        return index_ini

    @staticmethod
    def get_cumulative_distribution_function(probability_map):
        """Return the CDF of the flattened, normalized map and its sort order.

        Args:
            probability_map: Non-negative array with a positive sum.

        Returns:
            A tuple ``(cdf, sort_indices)``. ``cdf`` is the cumulative sum of
            the normalized map values sorted in ascending order, so its last
            element is 1 up to rounding. ``sort_indices`` maps each position in
            ``cdf`` back to a flat index into ``probability_map``.
        """
        flat_map = probability_map.flatten()
        flat_map_normalized = flat_map / flat_map.sum()
        # Keep the sorting indices so that the sorting can be inverted later on
        sort_indices = np.argsort(flat_map_normalized)
        flat_map_normalized_sorted = flat_map_normalized[sort_indices]
        cdf = np.cumsum(flat_map_normalized_sorted)
        return cdf, sort_indices

    def sample_probability_map(self):
        """Draw a patch center using inverse transform sampling.

        A uniform random number in :math:`[0, 1)` is located in ``self.cdf``,
        and the corresponding voxel of ``self.probability_map`` is returned.
        Each voxel is drawn with a frequency proportional to its value, and
        voxels with zero probability are never drawn.

        Example:
            For the (2D, for brevity) probability map::

                ((0, 0, 1, 1, 5, 2, 1, 1, 0),
                 (2, 2, 2, 2, 2, 2, 2, 2, 2))

            a histogram of 100000 draws looks like::

                ((    0,     0,  3479,  3478, 17121,  7023,  3355,  3378,     0),
                 ( 6808,  6804,  6942,  6809,  6946,  6988,  7002,  6826,  7041))

        Returns:
            Integer array with the index of the drawn voxel along each axis.
        """
        random_number = torch.rand(1).item()
        # First CDF entry strictly greater than the random number
        cdf_index = int(np.searchsorted(self.cdf, random_number, side='right'))
        # Rounding can leave cdf[-1] slightly below 1; fall back to the last
        # (most probable) location instead of running off the end
        cdf_index = min(cdf_index, len(self.cdf) - 1)
        random_location_index = self.sort_indices[cdf_index]
        center = np.unravel_index(
            random_location_index,
            self.probability_map.shape,
        )
        center = np.array(center).astype(int)
        return center

    def extract_patch(self) -> 'Subject':
        """Return a copy of the sample cropped to a randomly located patch."""
        # TODO: replace with Crop transform
        index_ini = self.get_random_index_ini()
        cropped_sample = self.copy_and_crop(index_ini)
        return cropped_sample

    def copy_and_crop(self, index_ini: np.ndarray) -> 'Subject':
        """Return a deep copy of the sample with every image cropped.

        Args:
            index_ini: Index of the first voxel of the patch along each of the
                three spatial axes.

        Returns:
            A deep copy of the sample in which the data of every image is
            replaced by an independent copy of the patch. The patch origin is
            stored under ``'index_ini'``.
        """
        index_fin = index_ini + self.patch_size
        cropped_sample = copy.deepcopy(self.sample)
        iterable = self.sample.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            patch = self.crop(image[DATA], index_ini, index_fin)
            # Slicing returns a view: copy it so the patch neither aliases the
            # original sample's data nor keeps the full volume alive
            if isinstance(patch, torch.Tensor):
                patch = patch.clone()
            else:
                patch = patch.copy()
            cropped_sample[image_name][DATA] = patch
        # torch doesn't like uint16
        cropped_sample['index_ini'] = index_ini.astype(int)
        return cropped_sample

    @staticmethod
    def crop(
            data: Union[np.ndarray, torch.Tensor],
            index_ini: np.ndarray,
            index_fin: np.ndarray,
            ) -> Union[np.ndarray, torch.Tensor]:
        """Slice the last three axes of ``data`` from ``index_ini`` to ``index_fin``.

        The result is a view of ``data`` (basic slicing), not a copy.
        """
        i_ini, j_ini, k_ini = index_ini
        i_fin, j_fin, k_fin = index_fin
        return data[..., i_ini:i_fin, j_ini:j_fin, k_ini:k_fin]