Passed
Push — master ( c9c9a5...7d9f03 )
by Fernando
01:40
created

  A

Complexity

Conditions 1

Size

Total Lines 16
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 15
nop 3
dl 0
loc 16
rs 9.65
c 0
b 0
f 0
1
from typing import Optional, Generator
2
3
import numpy as np
4
5
from ...typing import TypePatchSize, TypeTripletInt
6
from ...data.subject import Subject
7
from ...utils import to_tuple
8
9
10
class PatchSampler:
    r"""Base class for TorchIO samplers.

    Args:
        patch_size: Tuple of integers :math:`(w, h, d)` to generate patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided, :math:`w = h = d = n`.

    Raises:
        ValueError: If any patch dimension is not a positive integer or is
            too large to be stored as ``uint16``.

    .. warning:: This is an abstract class that should only be instantiated
        using child classes such as :class:`~torchio.data.UniformSampler` and
        :class:`~torchio.data.WeightedSampler`.
    """
    def __init__(self, patch_size: TypePatchSize):
        patch_size_array = np.array(to_tuple(patch_size, length=3))
        # ``self.patch_size`` is stored as uint16; larger values would wrap
        # around silently on ``astype`` instead of failing.
        max_size = np.iinfo(np.uint16).max
        for n in patch_size_array:
            # Check the type before comparing so that non-numeric input
            # (e.g. strings) raises this ValueError, not a TypeError from < 1.
            if not isinstance(n, (int, np.integer)) or n < 1:
                message = (
                    'Patch dimensions must be positive integers,'
                    f' not {patch_size_array}'
                )
                raise ValueError(message)
            if n > max_size:
                message = (
                    f'Patch dimensions must not be larger than {max_size},'
                    f' not {patch_size_array}'
                )
                raise ValueError(message)
        self.patch_size = patch_size_array.astype(np.uint16)

    def extract_patch(
            self,
            subject: Subject,
            index_ini: TypeTripletInt,
            ) -> Subject:
        """Crop a patch of size ``self.patch_size`` starting at ``index_ini``.

        Args:
            subject: Subject to extract the patch from.
            index_ini: Index of the first voxel of the patch along each axis.

        Returns:
            The cropped subject, with ``index_ini`` stored under the
            ``'index_ini'`` key as an integer NumPy array.
        """
        cropped_subject = self.crop(subject, index_ini, self.patch_size)
        cropped_subject['index_ini'] = np.array(index_ini).astype(int)
        return cropped_subject

    def crop(
            self,
            subject: Subject,
            index_ini: TypeTripletInt,
            patch_size: TypeTripletInt,
            ) -> Subject:
        """Return ``subject`` cropped to the box at ``index_ini`` of size
        ``patch_size``.

        Raises:
            ValueError: If the box does not fit inside the subject.
        """
        transform = self._get_crop_transform(subject, index_ini, patch_size)
        return transform(subject)

    @staticmethod
    def _get_crop_transform(
            subject,
            index_ini: TypeTripletInt,
            patch_size: TypePatchSize,
            ):
        """Build a ``Crop`` transform that keeps only the requested box.

        The cropping passed to ``Crop`` is
        ``(ini_0, fin_0, ini_1, fin_1, ini_2, fin_2)``. ``ini_i`` is the start
        index along axis ``i``. ``fin_i`` is the number of voxels left after
        the box along that axis.

        Raises:
            ValueError: If ``index_ini`` or ``patch_size`` do not have exactly
                three elements, or if the box does not fit inside the subject.
        """
        from ...transforms.preprocessing.spatial.crop import Crop
        # Use signed 64-bit arithmetic: with uint16, axes longer than 65535
        # voxels were truncated and a box running past the image edge made
        # ``shape - index_fin`` wrap around to a huge positive margin.
        shape = np.array(subject.spatial_shape, dtype=np.int64)
        index_ini = np.array(index_ini, dtype=np.int64)
        patch_size = np.array(patch_size, dtype=np.int64)
        # Explicit checks instead of ``assert`` (which ``python -O`` strips).
        if index_ini.shape != (3,):
            message = f'Initial index must have 3 elements, not {index_ini}'
            raise ValueError(message)
        if patch_size.shape != (3,):
            message = f'Patch size must have 3 elements, not {patch_size}'
            raise ValueError(message)
        index_fin = index_ini + patch_size
        if np.any(index_ini < 0) or np.any(index_fin > shape):
            message = (
                f'Patch with initial index {tuple(index_ini.tolist())} and'
                f' size {tuple(patch_size.tolist())} does not fit in an image'
                f' of spatial shape {tuple(shape.tolist())}'
            )
            raise ValueError(message)
        crop_ini = index_ini.tolist()
        crop_fin = (shape - index_fin).tolist()
        # Interleave start/end margins per axis: (ini0, fin0, ini1, fin1, ...)
        cropping = sum(zip(crop_ini, crop_fin), ())
        return Crop(cropping)

    def __call__(
            self,
            subject: Subject,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        """Validate ``subject`` and return the patch generator for it.

        Args:
            subject: Subject to sample patches from. It is checked with
                ``subject.check_consistent_space()`` before sampling.
            num_patches: Forwarded to :meth:`_generate_patches` only when not
                ``None``, so subclasses otherwise keep their own default.

        Raises:
            RuntimeError: If the patch size exceeds the subject's spatial
                shape along any axis.
        """
        # This method contains no ``yield``, so these checks run eagerly when
        # the sampler is called, not lazily on first iteration.
        subject.check_consistent_space()
        if np.any(self.patch_size > subject.spatial_shape):
            message = (
                f'Patch size {tuple(self.patch_size)} cannot be'
                f' larger than image size {tuple(subject.spatial_shape)}'
            )
            raise RuntimeError(message)
        kwargs = {} if num_patches is None else {'num_patches': num_patches}
        return self._generate_patches(subject, **kwargs)

    def _generate_patches(
            self,
            subject: Subject,
            num_patches: Optional[int] = None,
            ) -> Generator[Subject, None, None]:
        """Yield patches from ``subject``; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
91
92
93
class RandomSampler(PatchSampler):
    r"""Base class for random samplers.

    Concrete random samplers are expected to subclass this and provide
    :meth:`get_probability_map`.

    Args:
        patch_size: Tuple of integers :math:`(w, h, d)` to generate patches
            of size :math:`w \times h \times d`.
            If a single number :math:`n` is provided, :math:`w = h = d = n`.
    """
    def get_probability_map(self, subject: Subject):
        """Subclass hook returning a probability map for ``subject``.

        This base implementation provides no map. Judging by the name, the
        returned map is used to weight sampling locations; confirm against
        the concrete subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
103