Passed
Pull Request — master (#182)
by Fernando
01:18
created

PatchSampler.__call__()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 5
nop 3
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
from typing import Tuple, Optional, Generator
2
3
import numpy as np
4
5
from ... import TypePatchSize
6
from ...data.subject import Subject
7
from ...utils import to_tuple
8
9
10
class PatchSampler:
    r"""Base class for TorchIO samplers.

    Args:
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided, :math:`d = h = w = n`.

    Raises:
        ValueError: If any patch dimension is smaller than 1 or larger than
            the maximum value of ``np.uint16`` (the dtype used to store it).
    """
    def __init__(self, patch_size: TypePatchSize):
        patch_size_array = np.array(to_tuple(patch_size, length=3))
        if np.any(patch_size_array < 1):
            message = (
                'Patch dimensions must be positive integers,'
                f' not {patch_size_array}'
            )
            raise ValueError(message)
        # The patch size is stored as uint16; reject values that would
        # silently wrap around in that cast instead of storing garbage.
        max_size = np.iinfo(np.uint16).max
        if np.any(patch_size_array > max_size):
            message = (
                f'Patch dimensions must not exceed {max_size},'
                f' not {patch_size_array}'
            )
            raise ValueError(message)
        self.patch_size = patch_size_array.astype(np.uint16)

    def extract_patch(self):
        """Extract a patch; must be implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def get_crop_transform(
            image_size,
            index_ini,
            patch_size: TypePatchSize,
            ):
        """Build a ``Crop`` transform that keeps a single patch of an image.

        Args:
            image_size: Size of the image, one value per axis.
            index_ini: Index of the first voxel of the patch along each axis.
            patch_size: Size of the patch along each axis.

        Returns:
            A ``Crop`` built from the interleaved bounds
            ``(ini_0, fin_0, ini_1, fin_1, ...)``, where ``ini_i`` is the
            number of voxels before the patch and ``fin_i`` the number of
            voxels after it along axis ``i``.

        Raises:
            ValueError: If ``index_ini`` is negative or the patch extends
                beyond ``image_size`` along any axis.
        """
        from ...transforms.preprocessing.spatial.crop import Crop
        # Use signed 64-bit arithmetic: with unsigned 16-bit integers an
        # out-of-range patch made ``image_size - index_fin`` wrap around to
        # huge positive values instead of being detected.
        image_size = np.asarray(image_size, dtype=np.int64)
        index_ini = np.asarray(index_ini, dtype=np.int64)
        patch_size = np.asarray(patch_size, dtype=np.int64)
        index_fin = index_ini + patch_size
        if np.any(index_ini < 0) or np.any(index_fin > image_size):
            message = (
                f'Patch of size {patch_size.tolist()} at index'
                f' {index_ini.tolist()} does not fit in an image of size'
                f' {image_size.tolist()}'
            )
            raise ValueError(message)
        crop_ini = index_ini.tolist()
        crop_fin = (image_size - index_fin).tolist()
        # Interleave per-axis bounds: (ini_0, fin_0, ini_1, fin_1, ...).
        cropping: Tuple[int, ...] = tuple(
            bound
            for pair in zip(crop_ini, crop_fin)
            for bound in pair
        )
        return Crop(cropping)
48
49
50
class RandomSampler(PatchSampler):
    r"""Base class for random patch samplers.

    This class only declares the sampling interface: subclasses must
    implement :meth:`__call__` and :meth:`get_probability_map`, both of which
    raise :class:`NotImplementedError` here.

    Args:
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided, :math:`d = h = w = n`.
    """
    def __call__(
            self,
            sample: Subject,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        """Yield patches (as ``Subject`` instances) extracted from ``sample``.

        Args:
            sample: Subject to extract patches from.
            num_patches: Number of patches to yield. ``None`` presumably
                means "no limit" — TODO confirm against concrete subclasses.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement __call__()'
        )

    def get_probability_map(self, sample: Subject):
        """Return the probability map for ``sample``.

        NOTE(review): presumably the map used to choose patch locations —
        confirm against concrete subclasses.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError(
            f'{type(self).__name__} must implement get_probability_map()'
        )
67