Passed
Push — master ( 7b848f...cac223 )
by Fernando
02:39
created

LabelSampler.extract_patch_generator()   A

Complexity

Conditions 2

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 7
rs 10
c 0
b 0
f 0
cc 2
nop 3
1
from typing import Generator
2
from .sampler import ImageSampler, crop
3
from ... import DATA, LABEL, TYPE
4
from ..subject import Subject
5
6
7
class LabelSampler(ImageSampler):
    r"""Extract random patches containing labeled voxels.

    This iterable dataset yields patches that contain at least one
    non-background voxel.

    It extracts the label data from the first image in the sample with type
    :py:attr:`torchio.LABEL`.

    Args:
        sample: Sample generated by a
            :py:class:`~torchio.data.dataset.ImagesDataset`, from which image
            patches will be extracted.
        patch_size: Tuple of integers :math:`(d, h, w)` to generate patches
            of size :math:`d \times h \times w`.
            If a single number :math:`n` is provided,
            :math:`d = h = w = n`.

    .. warning:: For now, this implementation is not efficient because it uses
        brute force to look for foreground voxels. If the number of
        non-background voxels is very small, this sampler will be slow.
    """
    # pylint: disable=abstract-method
    def extract_patch_generator(
            self,
            sample: Subject,
            patch_size,
            ) -> Generator[dict, None, None]:
        """Yield patches from ``sample`` indefinitely.

        Each yielded patch is produced by :meth:`extract_patch`, so it
        contains at least one labeled voxel.
        """
        while True:
            yield self.extract_patch(sample, patch_size)

    @staticmethod
    def get_first_label_image_dict(sample: Subject):
        """Return the first image dict in ``sample`` whose type is LABEL.

        Images are visited in the order returned by
        ``sample.get_images(intensity_only=False)``.

        Raises:
            ValueError: If ``sample`` contains no image of type LABEL.
        """
        for image_dict in sample.get_images(intensity_only=False):
            if image_dict[TYPE] == LABEL:
                return image_dict
        raise ValueError('No images of type torchio.LABEL found in sample')

    def extract_patch(self, sample: Subject, patch_size):
        """Crop ``sample`` at a random location whose label patch is non-empty.

        Random patch locations are drawn repeatedly (rejection sampling)
        until the cropped label data sums to a positive value.

        Returns:
            The result of ``self.copy_and_crop`` at the accepted location.

        Raises:
            ValueError: If ``sample`` has no LABEL image, or if its label
                data contains no positive voxel. In the latter case no patch
                could ever be accepted and the search would never end.
        """
        label_image_data = self.get_first_label_image_dict(sample)[DATA]
        # A patch sum can only be > 0 if some voxel is > 0; without one,
        # the rejection loop below would spin forever.
        if not (label_image_data > 0).any():
            raise ValueError(
                'The label image in the sample has no foreground (positive)'
                ' voxels, so no patch containing labeled voxels can be'
                ' extracted'
            )
        # `while True` guarantees the indices are bound once we break out.
        while True:
            index_ini, index_fin = self.get_random_indices(sample, patch_size)
            patch_label = crop(label_image_data, index_ini, index_fin)
            if patch_label.sum() > 0:
                break
        return self.copy_and_crop(sample, index_ini, index_fin)
61