deepreg.dataset.loader.unpaired_loader   A
last analyzed

Complexity

Total Complexity 8

Size/Duplication

Total Lines 107
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 8
eloc 59
dl 0
loc 107
rs 10
c 0
b 0
f 0

4 Methods

Rating   Name   Duplication   Size   Complexity  
A UnpairedDataLoader.close() 0 7 2
A UnpairedDataLoader.sample_index_generator() 0 13 2
A UnpairedDataLoader.__init__() 0 44 2
A UnpairedDataLoader.validate_data_files() 0 13 2
1
"""
Load unpaired data.
Supported formats: h5 and Nifti.
Image data can be labeled or unlabeled.
"""
6
import random
7
from typing import List, Tuple, Union
8
9
from deepreg.dataset.loader.interface import (
10
    AbstractUnpairedDataLoader,
11
    GeneratorDataLoader,
12
)
13
from deepreg.dataset.util import check_difference_between_two_lists
14
from deepreg.registry import REGISTRY
15
16
17
@REGISTRY.register_data_loader(name="unpaired")
class UnpairedDataLoader(AbstractUnpairedDataLoader, GeneratorDataLoader):
    """
    Load unpaired data using given file loader. Handles both labeled
    and unlabeled cases.

    All images are read through a single file loader; the moving and
    fixed loaders are the same object, and image pairs are drawn by
    sample_index_generator.
    The function sample_index_generator needs to be defined for the
    GeneratorDataLoader class.
    """

    def __init__(
        self,
        file_loader,
        data_dir_paths: List[str],
        labeled: bool,
        sample_label: str,
        seed: int,
        image_shape: Union[Tuple[int, ...], List[int]],
    ):
        """
        Load data which are unpaired, labeled or unlabeled.

        :param file_loader: callable returning a file loader; it is called
            with the keyword arguments dir_paths, name and grouped.
        :param data_dir_paths: paths of the directories storing data,
            the data are saved under two different sub-directories:
            images, labels (labels is only read when the data is labeled).
        :param labeled: whether the data is labeled.
        :param sample_label: forwarded unchanged to the parent constructor.
        :param seed: seed of the random generator which shuffles the
            image indices in sample_index_generator.
        :param image_shape: (width, height, depth)
        :raises AssertionError: if data_dir_paths is not a list.
        """
        super().__init__(
            image_shape=image_shape,
            labeled=labeled,
            sample_label=sample_label,
            seed=seed,
        )
        # NOTE(review): assert statements are stripped under ``python -O``;
        # the assert is kept so the raised exception type does not change.
        assert isinstance(
            data_dir_paths, list
        ), f"data_dir_paths must be list of strings, got {data_dir_paths}"

        # Moving and fixed images come from the same pool of files,
        # so one loader instance serves both roles.
        loader_image = file_loader(
            dir_paths=data_dir_paths, name="images", grouped=False
        )
        self.loader_moving_image = loader_image
        self.loader_fixed_image = loader_image
        if self.labeled:
            loader_label = file_loader(
                dir_paths=data_dir_paths, name="labels", grouped=False
            )
            self.loader_moving_label = loader_label
            self.loader_fixed_label = loader_label
        self.validate_data_files()

        self.num_images = self.loader_moving_image.get_num_images()
        # Each sample consumes two images; with an odd number of images
        # the last shuffled index is never used.
        self._num_samples = self.num_images // 2

    def validate_data_files(self) -> None:
        """
        Verify all loader have the same files.

        Only performed for labeled data, where the image ids are compared
        against the label ids.
        Since fixed and moving loaders come from the same file_loader,
        there is no need to check both (avoid duplicate).
        """
        if self.labeled:
            image_ids = self.loader_moving_image.get_data_ids()
            label_ids = self.loader_moving_label.get_data_ids()
            check_difference_between_two_lists(
                list1=image_ids,
                list2=label_ids,
                name="images and labels in unpaired loader",
            )

    def sample_index_generator(self):
        """
        Generate sample indexes to load data using the
        GeneratorDataLoader class.

        The image indices are shuffled with a generator seeded by
        self.seed and then consumed two at a time, so an image index
        appears in at most one yielded sample.

        :return: yields tuples (moving_index, fixed_index, indices),
            where indices is the list [moving_index, fixed_index].
        """
        image_indices = list(range(self.num_images))
        # A new Random seeded with the same value is built on every call,
        # so the shuffled order is identical across calls.
        random.Random(self.seed).shuffle(image_indices)
        for sample_index in range(self.num_samples):
            moving_index = image_indices[2 * sample_index]
            fixed_index = image_indices[2 * sample_index + 1]
            yield moving_index, fixed_index, [moving_index, fixed_index]

    def close(self) -> None:
        """
        Close the moving files opened by the file_loaders.

        The fixed loaders are the same objects as the moving loaders,
        so they are not closed separately.
        """
        self.loader_moving_image.close()
        if self.labeled:
            self.loader_moving_label.close()
107