Passed
Pull Request — master (#334)
created by Fernando · 01:13

PatchSampler.get_crop_transform()   A

Complexity

Conditions 1

Size

Total Lines 16
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric  Value
cc      1
eloc    15
nop     3
dl      0
loc     16
rs      9.65
c       0
b       0
f       0
import copy
from typing import Tuple, Optional, Generator

import numpy as np

from ... import TypePatchSize, TypeTripletInt
from ...data.subject import Subject
from ...utils import to_tuple


class PatchSampler:
    r"""Base class for TorchIO samplers.

    Args:
        patch_size: Tuple of integers :math:`(w, h, d)` to generate patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided, :math:`w = h = d = n`.
    """
    def __init__(self, patch_size: TypePatchSize):
        patch_size_array = np.array(to_tuple(patch_size, length=3))
        for n in patch_size_array:
            if n < 1 or not isinstance(n, (int, np.integer)):
                message = (
                    'Patch dimensions must be positive integers,'
                    f' not {patch_size_array}'
                )
                raise ValueError(message)
        self.patch_size = patch_size_array.astype(np.uint16)

    def extract_patch(
            self,
            subject: Subject,
            index_ini: TypeTripletInt,
            ) -> Subject:
        cropped_subject = self.crop(subject, index_ini, self.patch_size)
        cropped_subject['index_ini'] = np.array(index_ini).astype(int)
        return cropped_subject

    def crop(
            self,
            subject: Subject,
            index_ini: TypeTripletInt,
            patch_size: TypeTripletInt,
            ) -> Subject:
        transform = self.get_crop_transform(subject, index_ini, patch_size)
        return transform(subject)

    @staticmethod
    def get_crop_transform(
            subject,
            index_ini,
            patch_size: TypePatchSize,
            ):
        from ...transforms.preprocessing.spatial.crop import Crop
        shape = np.array(subject.spatial_shape, dtype=np.uint16)
        index_ini = np.array(index_ini, dtype=np.uint16)
        patch_size = np.array(patch_size, dtype=np.uint16)
        index_fin = index_ini + patch_size
        crop_ini = index_ini.tolist()
        crop_fin = (shape - index_fin).tolist()
        start = ()
        cropping = sum(zip(crop_ini, crop_fin), start)
        return Crop(cropping)
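
The last few statements of get_crop_transform() build the cropping bounds by interleaving, axis by axis, the number of voxels to remove below the patch (crop_ini) with the number to remove above it (crop_fin). A minimal sketch of that arithmetic with made-up values (a hypothetical 181 × 217 × 181 volume, a single-number patch size of 64, and an initial index of (10, 20, 30)):

import numpy as np

# Hypothetical inputs; the names mirror get_crop_transform() above.
spatial_shape = np.array((181, 217, 181), dtype=np.uint16)  # subject.spatial_shape
index_ini = np.array((10, 20, 30), dtype=np.uint16)         # lower corner of the patch
patch_size = np.array((64, 64, 64), dtype=np.uint16)        # PatchSampler(64) expands 64 to (64, 64, 64)

index_fin = index_ini + patch_size               # upper corner: (74, 84, 94)
crop_ini = index_ini.tolist()                    # voxels to crop below the patch
crop_fin = (spatial_shape - index_fin).tolist()  # voxels to crop above it
cropping = sum(zip(crop_ini, crop_fin), ())      # interleave: (w_ini, w_fin, h_ini, h_fin, d_ini, d_fin)
print(cropping)  # (10, 107, 20, 133, 30, 87)

The resulting tuple is the argument handed to Crop; removing those amounts from both ends of every axis leaves exactly the requested 64 × 64 × 64 patch.
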
class RandomSampler(PatchSampler):
    r"""Base class for random TorchIO samplers.

    Args:
        patch_size: Tuple of integers :math:`(w, h, d)` to generate patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided, :math:`w = h = d = n`.
    """
    def __call__(
            self,
            sample: Subject,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        raise NotImplementedError

    def get_probability_map(self, sample: Subject):
        raise NotImplementedError
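
Both methods of RandomSampler raise NotImplementedError, so the class only fixes the sampling interface: concrete samplers are expected to override __call__ to yield cropped patches and get_probability_map to expose their sampling weights. Below is a minimal, hypothetical subclass sketch that reuses the names imported in the file above; the class name, the uniform probability map, and the sampling loop are illustrative assumptions rather than part of this module (TorchIO ships its own concrete samplers).

import numpy as np

# Hypothetical example only; not part of the module shown above.
class UniformRandomSampler(RandomSampler):
    """Yield patches whose lower corner is drawn uniformly from all valid positions."""

    def get_probability_map(self, sample: Subject):
        # Every voxel is an equally likely patch origin, so the map is constant.
        return np.ones(sample.spatial_shape, dtype=np.float32)

    def __call__(
            self,
            sample: Subject,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        # Largest valid lower corner along each axis, so the patch still fits.
        valid_range = np.array(sample.spatial_shape) - self.patch_size
        remaining = num_patches
        while remaining is None or remaining > 0:
            index_ini = tuple(
                int(np.random.randint(0, upper + 1)) for upper in valid_range
            )
            yield self.extract_patch(sample, index_ini)
            if remaining is not None:
                remaining -= 1
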