Passed
Pull Request — master (#175)
by Fernando
01:24
created

  A

Complexity

Conditions 2

Size

Total Lines 17
Code Lines 15

Duplication

Lines 17
Ratio 100 %

Importance

Changes 0
Metric Value
cc 2
eloc 15
nop 3
dl 17
loc 17
rs 9.65
c 0
b 0
f 0
1
from itertools import islice
2
from typing import Union, Sequence, Tuple
3
4
import torch
5
import numpy as np
6
7
from ...utils import to_tuple
8
9
10
class PatchSampler:
    r"""Base class for TorchIO samplers.

    Args:
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided, :math:`d = h = w = n`.

    Raises:
        ValueError: If any patch dimension is not a positive integer, or is
            too large to be stored as ``np.uint16``.
    """
    def __init__(self, patch_size: Union[int, Sequence[int]]):
        patch_size = np.array(to_tuple(patch_size, length=3))
        # The size is stored as uint16 below; a non-integer value would be
        # silently truncated by that cast, so reject it up front.
        is_integer = np.all(patch_size == np.round(patch_size))
        if not is_integer or np.any(patch_size < 1):
            message = (
                'Patch dimensions must be positive integers,'
                f' not {patch_size}'
            )
            raise ValueError(message)
        # Values above the uint16 range would silently wrap around on cast.
        max_size = np.iinfo(np.uint16).max
        if np.any(patch_size > max_size):
            message = (
                f'Patch dimensions must not be larger than {max_size},'
                f' not {patch_size}'
            )
            raise ValueError(message)
        self.patch_size = patch_size.astype(np.uint16)

    def __call__(self):
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def get_probability_map(self):
        """Must be implemented by subclasses."""
        raise NotImplementedError

    def extract_patch(self):
        """Must be implemented by subclasses."""
        raise NotImplementedError


def crop(
        image: Union[np.ndarray, torch.Tensor],
        index_ini: np.ndarray,
        index_fin: np.ndarray,
        ) -> Union[np.ndarray, torch.Tensor]:
    """Extract a sub-volume from the last three axes of ``image``.

    Args:
        image: Array or tensor whose last three axes are sliced; any leading
            axes are kept whole.
        index_ini: Three start indices (inclusive), one per sliced axis.
        index_fin: Three stop indices (exclusive), one per sliced axis.

    Returns:
        The sliced view, of the same type as ``image``.
    """
    start_i, start_j, start_k = index_ini
    stop_i, stop_j, stop_k = index_fin
    spatial_slices = (
        slice(start_i, stop_i),
        slice(start_j, stop_j),
        slice(start_k, stop_k),
    )
    return image[(Ellipsis,) + spatial_slices]


def get_random_indices_from_shape(
        shape: Tuple[int, int, int],
        patch_size: Tuple[int, int, int],
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Sample uniformly a patch location that fits entirely inside ``shape``.

    Args:
        shape: Spatial shape of the image.
        patch_size: Size of the patch along each axis.

    Returns:
        ``(index_ini, index_fin)``: start indices (inclusive, ``np.uint16``)
        and stop indices (exclusive) of the patch along each axis.

    Raises:
        ValueError: If the patch is larger than the image along any axis.
    """
    shape_array = np.array(shape)
    patch_size_array = np.array(patch_size)
    max_index_ini = shape_array - patch_size_array
    if (max_index_ini < 0).any():
        message = (
            f'Patch size {patch_size} must not be'
            f' larger than image size {shape}'
        )
        raise ValueError(message)
    coordinates = []
    for max_coordinate in max_index_ini.tolist():
        if max_coordinate == 0:
            # The patch spans the whole axis; the only valid start is 0.
            coordinate = 0
        else:
            # Valid starts are 0..max_coordinate inclusive, but torch.randint's
            # upper bound is exclusive, hence the ``+ 1``. Without it, patches
            # touching the far border of the image could never be sampled.
            upper = max_coordinate + 1
            coordinate = int(torch.randint(upper, size=(1,)).item())
        coordinates.append(coordinate)
    index_ini = np.array(coordinates, np.uint16)
    index_fin = index_ini + patch_size_array
    return index_ini, index_fin