Passed
Push — master ( fca4b7...25d6cf )
by Fernando
01:33
created

torchio.data.sampler.sampler.ImageSampler.copy_and_crop()   A

Complexity

Conditions 2

Size

Total Lines 17
Code Lines 15

Duplication

Lines 17
Ratio 100 %

Importance

Changes 0
Metric Value
cc 2
eloc 15
nop 3
dl 17
loc 17
rs 9.65
c 0
b 0
f 0
1
from typing import Tuple, Optional, Generator
2
3
import numpy as np
4
5
from ... import TypePatchSize
6
from ...data.subject import Subject
7
from ...utils import to_tuple
8
9
10
class PatchSampler:
    r"""Base class for TorchIO samplers.

    Subclasses are expected to override :meth:`__call__`,
    :meth:`get_probability_map` and :meth:`extract_patch`; the versions
    defined here only raise :class:`NotImplementedError`.

    Args:
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided, :math:`d = h = w = n`.

    Raises:
        ValueError: If any patch dimension is smaller than 1.
    """
    # Annotations naming project types are written as strings so that they
    # are not evaluated when the class body is executed.
    def __init__(self, patch_size: 'TypePatchSize'):
        patch_size_array = np.array(to_tuple(patch_size, length=3))
        if np.any(patch_size_array < 1):
            message = (
                'Patch dimensions must be positive integers,'
                f' not {patch_size_array}'
            )
            raise ValueError(message)
        # Stored as an unsigned 16-bit array of three values
        self.patch_size = patch_size_array.astype(np.uint16)

    def __call__(
            self,
            sample: 'Subject',
            num_patches: Optional[int] = None,
            ) -> Generator['Subject', None, None]:
        """Generate patches from ``sample``. Must be overridden.

        Args:
            sample: Subject to extract patches from.
            num_patches: Optional number of patches to generate.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def get_probability_map(self, sample: 'Subject'):
        """Return the probability map for ``sample``. Must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def extract_patch(self):
        """Extract a single patch. Must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @staticmethod
    def get_crop_transform(
            sample,
            index_ini,
            patch_size: 'TypePatchSize',
            ):
        """Build the ``Crop`` transform that keeps only the requested patch.

        Args:
            sample: Object exposing a ``spatial_shape`` attribute.
            index_ini: Index of the first voxel of the patch along each
                spatial axis.
            patch_size: Size of the patch along each spatial axis.

        Returns:
            A ``Crop`` instance built from the interleaved margins
            ``(ini_0, fin_0, ini_1, fin_1, ...)``, where ``ini`` is the
            margin before the patch and ``fin`` the margin after it.

        Raises:
            ValueError: If ``index_ini`` has a negative component or the
                patch extends beyond ``sample.spatial_shape``.
        """
        # Signed 64-bit arithmetic: with unsigned 16-bit integers, a patch
        # reaching past the image made ``shape - index_fin`` wrap around to
        # a huge positive margin instead of failing.
        shape = np.array(sample.spatial_shape, dtype=np.int64)
        index_ini = np.array(index_ini, dtype=np.int64)
        patch_size = np.array(patch_size, dtype=np.int64)
        index_fin = index_ini + patch_size
        if np.any(index_ini < 0) or np.any(index_fin > shape):
            message = (
                f'Patch with initial index {index_ini.tolist()}'
                f' and size {patch_size.tolist()}'
                f' does not fit in an image of shape {shape.tolist()}'
            )
            raise ValueError(message)
        crop_ini = index_ini.tolist()
        crop_fin = (shape - index_fin).tolist()
        # Imported inside the function, as in the original code
        # (presumably to avoid a circular import -- TODO confirm)
        from ...transforms.preprocessing.spatial.crop import Crop
        cropping: Tuple[int, int, int, int, int, int] = tuple(
            margin
            for margins in zip(crop_ini, crop_fin)
            for margin in margins
        )
        return Crop(cropping)
58