Passed
Push — master ( 4f7fe6...98f4df )
by Fernando
01:12
created

ImageSampler.get_stream()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 3
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
import copy
2
from typing import Union, Sequence, Generator, Tuple
3
4
import numpy as np
5
import torch
6
from torch.utils.data import IterableDataset
7
8
from ...torchio import DATA
9
from ...utils import to_tuple
10
from ..subject import Subject
11
12
13
class ImageSampler(IterableDataset):
    r"""Extract random patches from a volume.

    Args:
        sample: Sample generated by a
            :py:class:`~torchio.data.dataset.ImagesDataset`, from which image
            patches will be extracted.
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided, :math:`d = h = w = n`.
    """
    def __init__(self, sample: Subject, patch_size: Union[int, Sequence[int]]):
        self.sample = sample
        patch_size = to_tuple(patch_size, length=3)
        # Stored unsigned; any arithmetic that may go negative must be done
        # in a signed dtype (see get_random_indices).
        self.patch_size = np.array(patch_size, dtype=np.uint16)

    def __iter__(self) -> Generator[Subject, None, None]:
        """Yield random patches indefinitely; the consumer decides when to stop."""
        while True:
            yield self.extract_patch()

    def extract_patch(self) -> Subject:
        """Return a copy of ``self.sample`` cropped to one random patch."""
        index_ini, index_fin = self.get_random_indices(
            self.sample, self.patch_size)
        return self.copy_and_crop(self.sample, index_ini, index_fin)

    @staticmethod
    def get_random_indices(sample: Subject, patch_size: Tuple[int, int, int]):
        """Draw random start/stop indices for a patch inside ``sample``.

        Args:
            sample: Subject whose images are checked for a consistent shape
                via ``check_consistent_shape()``.
            patch_size: Patch size along the three spatial axes.

        Returns:
            ``(index_ini, index_fin)`` as produced by
            ``get_random_indices_from_shape``.

        Raises:
            ValueError: If ``sample`` contains no images, or (from
                ``get_random_indices_from_shape``) if the patch does not fit.
        """
        # Assume all images in sample have the same shape
        sample.check_consistent_shape()
        # Look up the first *image*, not the first key: the subject may hold
        # non-image entries, and get_images_dict is what copy_and_crop uses.
        images = sample.get_images_dict(intensity_only=False)
        first_image = next(iter(images.values()), None)
        if first_image is None:
            raise ValueError('Sample contains no images to sample patches from')
        first_image_array = first_image[DATA]
        # first_image_array should have shape (1, H, W, D)
        # Use a signed dtype: with uint16, ``shape - patch_size`` wraps around
        # instead of going negative, which would hide the "patch too large"
        # check downstream.
        shape = np.array(first_image_array.shape[1:], dtype=np.int64)
        return get_random_indices_from_shape(shape, patch_size)

    @staticmethod
    def copy_and_crop(
            sample: Subject,
            index_ini: np.ndarray,
            index_fin: np.ndarray,
            ) -> Subject:
        """Return a deep copy of ``sample`` with every image cropped.

        Args:
            sample: Subject to copy; it is not modified.
            index_ini: Start indices of the patch along the spatial axes.
            index_fin: Stop indices of the patch along the spatial axes.

        Returns:
            The copied subject, with each image's ``DATA`` replaced by the
            cropped array and an ``'index_ini'`` entry holding the start
            indices as a platform ``int`` array.
        """
        # One deep copy of the subject already duplicates every image entry;
        # deep-copying each image again only repeated that work.
        cropped_sample = copy.deepcopy(sample)
        iterable = sample.get_images_dict(intensity_only=False).items()
        for image_name, image in iterable:
            # NOTE(review): the crop is taken from the ORIGINAL image data, so
            # with basic slicing the patch likely shares memory with
            # ``sample``; in-place edits of the patch could write through.
            # Behavior kept as before — confirm whether that is intended.
            cropped_sample[image_name][DATA] = crop(
                image[DATA], index_ini, index_fin)
        # torch doesn't like uint16
        cropped_sample['index_ini'] = index_ini.astype(int)
        return cropped_sample
70
71
72
def crop(
        image: Union[np.ndarray, torch.Tensor],
        index_ini: np.ndarray,
        index_fin: np.ndarray,
        ) -> Union[np.ndarray, torch.Tensor]:
    """Slice a box out of the last three axes of ``image``.

    All leading axes (e.g. a channel axis) are kept whole.

    Args:
        image: Array or tensor whose last three axes are spatial.
        index_ini: Exactly three start indices, one per spatial axis.
        index_fin: Exactly three stop indices, one per spatial axis.

    Returns:
        ``image`` indexed with ``start:stop`` on each of its last three axes.
    """
    (start_i, start_j, start_k), (stop_i, stop_j, stop_k) = index_ini, index_fin
    spatial_box = (
        slice(start_i, stop_i),
        slice(start_j, stop_j),
        slice(start_k, stop_k),
    )
    return image[(Ellipsis, *spatial_box)]
80
81
82
def get_random_indices_from_shape(
        shape: Tuple[int, int, int],
        patch_size: Tuple[int, int, int],
        ) -> Tuple[np.ndarray, np.ndarray]:
    """Draw a random patch location that fits entirely inside ``shape``.

    Args:
        shape: Size of the volume along each of the three spatial axes.
        patch_size: Size of the patch along each of the three spatial axes.

    Returns:
        ``(index_ini, index_fin)``. ``index_ini`` holds the start indices as
        ``np.uint16``; ``index_fin = index_ini + patch_size``. Each start
        index is drawn uniformly from ``0`` to ``shape - patch_size``
        inclusive, so a patch flush with the far edge is also possible.

    Raises:
        ValueError: If ``patch_size`` exceeds ``shape`` along any axis.
    """
    shape_array = np.array(shape)
    patch_size_array = np.array(patch_size)
    # Subtract in a signed dtype: callers may pass unsigned arrays (e.g.
    # uint16), for which a negative difference would wrap around to a large
    # positive value and silently bypass the size check below.
    max_index_ini = (
        shape_array.astype(np.int64) - patch_size_array.astype(np.int64)
    )
    if (max_index_ini < 0).any():
        message = (
            f'Patch size {patch_size} must not be'
            f' larger than image size {shape}'
        )
        raise ValueError(message)
    # torch.randint's upper bound is exclusive; +1 makes the last valid
    # start index reachable (and yields 0 when the patch fills the axis).
    coordinates = [
        torch.randint(max_coordinate + 1, size=(1,)).item()
        for max_coordinate in max_index_ini.tolist()
    ]
    index_ini = np.array(coordinates, np.uint16)
    index_fin = index_ini + patch_size_array
    return index_ini, index_fin
106