Passed
Pull Request — master (#175)
by Fernando
01:04
created

PatchSampler.get_crop_transform()   A

Complexity

Conditions 1

Size

Total Lines 11
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 11
nop 3
dl 0
loc 11
rs 9.85
c 0
b 0
f 0
1
from typing import Union, Sequence, Tuple
2
3
import torch
4
import numpy as np
5
6
from ...utils import to_tuple
7
8
9
class PatchSampler:
    r"""Base class for TorchIO samplers.

    Args:
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided, :math:`d = h = w = n`.

    Raises:
        ValueError: If any patch dimension is smaller than 1.
    """
    def __init__(self, patch_size: Union[int, Sequence[int]]):
        patch_size = np.array(to_tuple(patch_size, length=3))
        if np.any(patch_size < 1):
            message = (
                'Patch dimensions must be positive integers,'
                f' not {patch_size}'
            )
            raise ValueError(message)
        # NOTE(review): values above 65535 would wrap when cast to uint16;
        # the dtype is kept because subclasses may rely on it — confirm.
        self.patch_size = patch_size.astype(np.uint16)

    def __call__(self):
        """Not implemented in the base class; subclasses must override."""
        raise NotImplementedError

    def get_probability_map(self):
        """Not implemented in the base class; subclasses must override."""
        raise NotImplementedError

    def extract_patch(self):
        """Not implemented in the base class; subclasses must override."""
        raise NotImplementedError

    @staticmethod
    def _get_crop_bounds(shape, index_ini, patch_size) -> Tuple[int, ...]:
        """Compute the per-axis amounts to crop so only the patch remains.

        Args:
            shape: Spatial shape of the full image, one value per axis.
            index_ini: Index of the first voxel of the patch, one per axis.
            patch_size: Size of the patch, one value per axis.

        Returns:
            Tuple of Python ints interleaved as
            ``(ini_0, fin_0, ini_1, fin_1, ...)``, where ``ini_i`` is the
            amount removed before the patch and ``fin_i`` the amount removed
            after it along axis ``i``.

        Raises:
            ValueError: If the patch starts at a negative index or extends
                past the end of ``shape`` along any axis.
        """
        # A wide signed dtype is used on purpose: with unsigned 16-bit
        # arithmetic, an out-of-bounds patch (or any dimension > 65535)
        # silently wraps around instead of being reported.
        shape = np.asarray(shape, dtype=np.int64)
        index_ini = np.asarray(index_ini, dtype=np.int64)
        patch_size = np.asarray(patch_size, dtype=np.int64)
        index_fin = index_ini + patch_size
        if np.any(index_ini < 0) or np.any(index_fin > shape):
            message = (
                f'Patch with initial index {index_ini.tolist()} and size'
                f' {patch_size.tolist()} does not fit in spatial shape'
                f' {shape.tolist()}'
            )
            raise ValueError(message)
        crop_ini = index_ini.tolist()
        crop_fin = (shape - index_fin).tolist()
        # Flatten [(ini_0, fin_0), (ini_1, fin_1), ...] into one flat tuple.
        return tuple(
            amount
            for pair in zip(crop_ini, crop_fin)
            for amount in pair
        )

    @staticmethod
    def get_crop_transform(sample, index_ini, patch_size):
        """Build a ``Crop`` transform that extracts one patch from ``sample``.

        Args:
            sample: Object exposing a ``spatial_shape`` attribute.
            index_ini: Index of the first voxel of the patch, one per axis.
            patch_size: Size of the patch, one value per axis.

        Returns:
            A ``Crop`` instance built from the interleaved
            ``(ini_0, fin_0, ini_1, fin_1, ...)`` crop amounts.

        Raises:
            ValueError: If the patch does not fit inside the sample.
        """
        # Imported lazily, as in the original; presumably this avoids a
        # circular import between samplers and transforms — confirm.
        from ...transforms.preprocessing.spatial.crop import Crop
        cropping = PatchSampler._get_crop_bounds(
            sample.spatial_shape, index_ini, patch_size)
        return Crop(cropping)
47