Passed
Pull Request — master (#175)
by Fernando
01:01
created

PatchSampler.__init__()   A

Complexity

Conditions 2

Size

Total Lines 9
Code Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 7
nop 2
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
from typing import Union, Sequence, Tuple
2
3
import torch
4
import numpy as np
5
6
from ...utils import to_tuple
7
8
9
class PatchSampler:
    r"""Base class for samplers that randomly extract patches from a volume.

    The methods :meth:`__call__`, :meth:`get_probability_map` and
    :meth:`extract_patch` are stubs here that raise
    :class:`NotImplementedError`; subclasses are expected to override them.

    Args:
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided, :math:`d = h = w = n`.

    Raises:
        ValueError: If any patch dimension is smaller than 1, is not a whole
            number, or is too large to be stored as an unsigned 16-bit
            integer.
    """
    def __init__(self, patch_size: Union[int, Sequence[int]]):
        patch_size = np.array(to_tuple(patch_size, length=3))
        # The size is stored as uint16 below. Reject values that the cast
        # would silently truncate (e.g. 1.5 -> 1) instead of accepting them.
        is_whole = np.mod(patch_size, 1) == 0
        if np.any(patch_size < 1) or not np.all(is_whole):
            message = (
                'Patch dimensions must be positive integers,'
                f' not {patch_size}'
            )
            raise ValueError(message)
        # Values above the uint16 range would wrap around silently.
        max_size = np.iinfo(np.uint16).max
        if np.any(patch_size > max_size):
            message = (
                f'Patch dimensions must not be larger than {max_size},'
                f' not {patch_size}'
            )
            raise ValueError(message)
        self.patch_size = patch_size.astype(np.uint16)

    def __call__(self):
        raise NotImplementedError

    def get_probability_map(self):
        raise NotImplementedError

    def extract_patch(self):
        raise NotImplementedError
38
39
40
41
def crop(
        image: Union[np.ndarray, torch.Tensor],
        index_ini: np.ndarray,
        index_fin: np.ndarray,
        ) -> Union[np.ndarray, torch.Tensor]:
    """Return the sub-volume of ``image`` between two corner indices.

    Only the last three axes are sliced; any leading axes are kept whole.

    Args:
        image: Array or tensor whose last three axes are spatial.
        index_ini: Three start indices (inclusive), one per spatial axis.
        index_fin: Three stop indices (exclusive), one per spatial axis.

    Returns:
        The cropped view produced by basic slicing of ``image``.

    Raises:
        ValueError: If either index sequence does not have exactly three
            elements.
    """
    (start_i, start_j, start_k), (stop_i, stop_j, stop_k) = index_ini, index_fin
    region = (
        Ellipsis,
        slice(start_i, stop_i),
        slice(start_j, stop_j),
        slice(start_k, stop_k),
    )
    return image[region]
49
50
51
def get_random_indices_from_shape(
        shape: Tuple[int, int, int],
        patch_size: Tuple[int, int, int],
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Draw a uniformly random patch location that fits inside ``shape``.

    Each start coordinate is sampled independently and uniformly from every
    valid position ``0 .. shape[i] - patch_size[i]`` (both ends included),
    using torch's global random number generator.

    Args:
        shape: Spatial shape of the volume.
        patch_size: Spatial size of the patch.

    Returns:
        A tuple ``(index_ini, index_fin)``. ``index_ini`` holds the inclusive
        start indices as a ``uint16`` array. ``index_fin`` holds the
        exclusive stop indices, computed as ``index_ini + patch_size``.

    Raises:
        ValueError: If the patch is larger than the shape along any axis.
    """
    shape_array = np.array(shape)
    patch_size_array = np.array(patch_size)
    max_index_ini = shape_array - patch_size_array
    if (max_index_ini < 0).any():
        message = (
            f'Patch size {patch_size} must not be'
            f' larger than image size {shape}'
        )
        raise ValueError(message)
    # torch.randint's upper bound is exclusive, and max_coordinate itself is
    # a valid start index, so sample from [0, max_coordinate + 1). When
    # max_coordinate is 0, randint(1) always yields 0.
    coordinates = [
        int(torch.randint(max_coordinate + 1, size=(1,)).item())
        for max_coordinate in max_index_ini.tolist()
    ]
    index_ini = np.array(coordinates, np.uint16)
    index_fin = index_ini + patch_size_array
    return index_ini, index_fin
75