Issues (32)

deepreg/dataset/loader/h5_loader.py (1 issue)

1
"""
2
Load h5 files and associated information.
3
"""
4
import os
5
from typing import List, Tuple, Union
6
7
import h5py
8
import numpy as np
9
10
from deepreg.dataset.loader.interface import FileLoader
11
from deepreg.registry import REGISTRY
12
13
# Template of the dataset keys inside a grouped h5 file. The two slots are
# filled with the group name and the in-group data key (in that order) when
# H5FileLoader.get_data rebuilds the key, e.g. "group-1-2".
DATA_KEY_FORMAT = "group-{}-{}"
14
15
16
@REGISTRY.register_file_loader(name="h5")
class H5FileLoader(FileLoader):
    """
    Generalized loader for h5 files.

    One h5 file named ``<name>.h5`` is expected in each directory of
    ``dir_paths``. The files are opened read-only at construction time and
    stay open until :meth:`close` is called.
    """

    def __init__(self, dir_paths: List[str], name: str, grouped: bool):
        """
        Init.

        :param dir_paths: path of h5 files.
        :param name: name is used to identify the file names.
        :param grouped: whether the data is grouped.
        """
        super().__init__(dir_paths=dir_paths, name=name, grouped=grouped)
        # both attributes are populated by set_data_structure
        self.h5_files = None
        self.data_path_splits = None
        self.set_data_structure()
        # group_struct stays None for non-grouped data
        self.group_struct = None
        if self.grouped:
            self.set_group_structure()

    def set_data_structure(self):
        """
        Store the data structure in memory so that
        we can retrieve data using data_index.
        This function sets two attributes:

        - h5_files, a dict such that h5_files[dir_path] = opened h5 file handle
        - data_path_splits, a list of string tuples to identify path of data

          - if grouped, a split is (dir_path, group_name, data_key) such that
            data = h5_files[dir_path]["group-{group_name}-{data_key}"]
          - if not grouped, a split is (dir_path, data_key) such that
            data = h5_files[dir_path][data_key]

        If anything fails while scanning the directories, every h5 file
        opened so far is closed before the error is propagated.

        :raises AssertionError: if an h5 file does not exist, or if the
            data is grouped and a key is not of form group-X-Y.
        :raises ValueError: if no data has been found in any h5 file.
        """
        h5_files = {}
        data_path_splits = []
        try:
            for dir_path in self.dir_paths:
                h5_file_path = os.path.join(dir_path, self.name + ".h5")
                assert os.path.exists(
                    h5_file_path
                ), f"h5 file {h5_file_path} does not exist"
                h5_file = h5py.File(h5_file_path, "r")
                # register the handle right away so that it is released
                # by the except clause below if a later check fails
                h5_files[dir_path] = h5_file

                keys = sorted(h5_file.keys())
                if self.grouped:
                    # each element is (dir_path, group_name, data_key)
                    # check h5 file keys
                    key_splits = [k.split("-") for k in keys]
                    assert all(
                        len(x) == 3 and x[0] == "group" for x in key_splits
                    ), f"h5_file keys must be of form group-X-Y, got {key_splits}"
                    data_path_splits += [(dir_path, k[1], k[2]) for k in key_splits]
                else:
                    # each element is (dir_path, data_key)
                    data_path_splits += [(dir_path, k) for k in keys]
            if not data_path_splits:
                raise ValueError(
                    f"No data collected from {self.dir_paths} in H5FileLoader, "
                    f"please verify the path is correct."
                )
        except Exception:
            # the handles are not yet reachable through self.h5_files,
            # so close() could never release them: do it here
            for opened_file in h5_files.values():
                opened_file.close()
            raise
        self.h5_files = h5_files
        self.data_path_splits = data_path_splits

    def set_group_structure(self):
        """
        Similar to NiftiLoader
        as the first two tokens of a split forms a group_id.
        Store the group structure in group_struct so that
        group_struct[group_index] = list of data_index.
        Retrieve data using (group_index, in_group_data_index).
        data_index = group_struct[group_index][in_group_data_index].
        """
        # group_struct_dict[group_id] = list of data_index
        group_struct_dict = {}
        for data_index, split in enumerate(self.data_path_splits):
            # group_id is (dir_path, group_name)
            group_struct_dict.setdefault(split[:2], []).append(data_index)
        # group_struct[group_index] = list of data_index,
        # groups are ordered by their sorted group_id
        self.group_struct = [
            group_struct_dict[group_id] for group_id in sorted(group_struct_dict)
        ]

    def get_data(self, index: Union[int, Tuple[int, ...]]) -> np.ndarray:
        """
        Get one data array by specifying an index

        :param index: the data index which is required

          - for paired or unpaired, the index is one single int, data_index
            (numpy integer scalars are accepted as well)
          - for grouped, the index is a tuple of two ints,
            (group_index, in_group_data_index)
        :returns arr: the data array at the specified index, as float32
        :raises ValueError: if index is neither an int nor a tuple
            of length two.
        """
        assert self.data_path_splits is not None
        invalid_index_msg = (
            f"index for H5FileLoader.get_data must be int, "
            f"or tuple of length two, got {index}"
        )
        if isinstance(index, (int, np.integer)):  # paired or unpaired
            assert not self.grouped
            assert 0 <= index
            dir_path, data_key = self.data_path_splits[index]
        elif isinstance(index, tuple):
            assert self.grouped
            if len(index) != 2:
                # report the documented message instead of an opaque
                # tuple-unpacking error
                raise ValueError(invalid_index_msg)
            group_index, in_group_data_index = index
            assert 0 <= group_index
            assert 0 <= in_group_data_index
            data_index = self.group_struct[group_index][in_group_data_index]
            dir_path, group_name, data_key = self.data_path_splits[data_index]
            # rebuild the key under which the data is stored in the h5 file
            data_key = DATA_KEY_FORMAT.format(group_name, data_key)
        else:
            raise ValueError(invalid_index_msg)
        arr = np.asarray(self.h5_files[dir_path][data_key], dtype=np.float32)
        if arr.ndim == 4 and arr.shape[3] == 1:
            # for labels, if there's only one label, remove the last dimension
            # currently have not encountered
            arr = arr[:, :, :, 0]  # pragma: no cover
        return arr

    def get_data_ids(self) -> List:
        """
        Get the unique IDs of data in this data set to
        verify consistency between
        images and label, moving and fixed.

        :return: data_path_splits as the data can be identified
            using dir_path and data_key
        """
        return self.data_path_splits  # type: ignore

    def get_num_images(self) -> int:
        """
        :return: int, number of images in this data set
        """
        return len(self.data_path_splits)  # type: ignore

    def close(self):
        """
        Close opened h5 file handles.

        Does nothing if no file handle has been stored.
        """
        if not self.h5_files:
            # h5_files is None until set_data_structure succeeded
            return
        for h5_file in self.h5_files.values():
            h5_file.close()
158