1 | """ |
||
2 | Load h5 files and associated information. |
||
3 | """ |
||
4 | import os |
||
5 | from typing import List, Tuple, Union |
||
6 | |||
7 | import h5py |
||
8 | import numpy as np |
||
9 | |||
10 | from deepreg.dataset.loader.interface import FileLoader |
||
11 | from deepreg.registry import REGISTRY |
||
12 | |||
# Template of the h5 dataset key used for grouped data,
# filled as DATA_KEY_FORMAT.format(group_name, data_key).
DATA_KEY_FORMAT = "group-{}-{}"
||
14 | |||
15 | |||
@REGISTRY.register_file_loader(name="h5")
class H5FileLoader(FileLoader):
    """
    Generalized loader for h5 files.

    For each directory in ``dir_paths`` the file ``<name>.h5`` is opened
    read-only; the handles stay open until :meth:`close` is called.
    """

    def __init__(self, dir_paths: List[str], name: str, grouped: bool):
        """
        Init.

        :param dir_paths: directories, each expected to contain
            a file named ``<name>.h5``.
        :param name: name is used to identify the file names.
        :param grouped: whether the data is grouped.
        """
        super().__init__(dir_paths=dir_paths, name=name, grouped=grouped)
        self.h5_files = None
        self.data_path_splits = None
        self.set_data_structure()
        self.group_struct = None
        if self.grouped:
            self.set_group_structure()

    def set_data_structure(self):
        """
        Store the data structure in memory so that
        we can retrieve data using data_index.
        This function sets two attributes:

        - h5_files, a dict such that h5_files[dir_path] = opened h5 file handle
        - data_path_splits, a list of string tuples to identify path of data

          - if grouped, a split is (dir_path, group_name, data_key) such that
            data = h5_files[dir_path]["group-{group_name}-{data_key}"]
          - if not grouped, a split is (dir_path, data_key) such that
            data = h5_files[dir_path][data_key]

        If anything fails, every file opened so far is closed
        before the error propagates.

        :raises AssertionError: if an h5 file does not exist, or,
            when grouped, a key is not of form group-X-Y.
        :raises ValueError: if no data is found in any of the files.
        """
        h5_files = {}
        data_path_splits = []
        try:
            for dir_path in self.dir_paths:
                h5_file_path = os.path.join(dir_path, self.name + ".h5")
                assert os.path.exists(
                    h5_file_path
                ), f"h5 file {h5_file_path} does not exist"
                h5_file = h5py.File(h5_file_path, "r")
                # register the handle before validating its keys,
                # so that it is closed if the validation fails
                h5_files[dir_path] = h5_file

                if self.grouped:
                    # each element is (dir_path, group_name, data_key)
                    # check h5 file keys
                    key_splits = [k.split("-") for k in sorted(h5_file.keys())]
                    assert all(
                        len(x) == 3 and x[0] == "group" for x in key_splits
                    ), f"h5_file keys must be of form group-X-Y, got {key_splits}"
                    data_path_splits += [(dir_path, k[1], k[2]) for k in key_splits]
                else:
                    # each element is (dir_path, data_key)
                    data_path_splits += [(dir_path, k) for k in sorted(h5_file.keys())]
            if len(data_path_splits) == 0:
                raise ValueError(
                    f"No data collected from {self.dir_paths} in H5FileLoader, "
                    f"please verify the path is correct."
                )
        except Exception:
            # the loader is unusable, do not leak the opened file handles
            for h5_file in h5_files.values():
                h5_file.close()
            raise
        self.h5_files = h5_files
        self.data_path_splits = data_path_splits

    def set_group_structure(self):
        """
        Similar to NiftiLoader
        as the first two tokens of a split forms a group_id.
        Store the group structure in group_struct so that
        group_struct[group_index] = list of data_index.
        Retrieve data using (group_index, in_group_data_index).
        data_index = group_struct[group_index][in_group_data_index].
        """
        # group_struct_dict[group_id] = list of data_index
        group_struct_dict = {}
        for data_index, split in enumerate(self.data_path_splits):
            # group_id is (dir_path, group_name)
            group_struct_dict.setdefault(split[:2], []).append(data_index)
        # group_struct[group_index] = list of data_index,
        # groups are ordered by sorted group_id
        self.group_struct = [
            group_struct_dict[group_id] for group_id in sorted(group_struct_dict)
        ]

    def get_data(self, index: Union[int, Tuple[int, ...]]) -> np.ndarray:
        """
        Get one data array by specifying an index.

        :param index: the data index which is required

            - for paired or unpaired, the index is one single int, data_index
              (numpy integers are accepted as well)
            - for grouped, the index is a tuple of two ints,
              (group_index, in_group_data_index)
        :return: the data array at the specified index, as float32;
            a 4D array whose last dimension is 1 is returned as 3D.
        :raises ValueError: if index is neither an int nor a tuple.
        """
        assert self.data_path_splits is not None
        # np.integer is checked explicitly as it is not a subclass of int
        if isinstance(index, (int, np.integer)):  # paired or unpaired
            assert not self.grouped
            # negative values would silently index from the end
            assert 0 <= index
            dir_path, data_key = self.data_path_splits[index]
        elif isinstance(index, tuple):
            assert self.grouped
            group_index, in_group_data_index = index
            assert 0 <= group_index
            assert 0 <= in_group_data_index
            data_index = self.group_struct[group_index][in_group_data_index]
            dir_path, group_name, data_key = self.data_path_splits[data_index]
            data_key = DATA_KEY_FORMAT.format(group_name, data_key)
        else:
            raise ValueError(
                f"index for H5FileLoader.get_data must be int, "
                f"or tuple of length two, got {index}"
            )
        arr = np.asarray(self.h5_files[dir_path][data_key], dtype=np.float32)
        if len(arr.shape) == 4 and arr.shape[3] == 1:
            # for labels, if there's only one label, remove the last dimension
            # currently have not encountered
            arr = arr[:, :, :, 0]  # pragma: no cover
        return arr

    def get_data_ids(self) -> List:
        """
        Get the unique IDs of data in this data set to
        verify consistency between
        images and label, moving and fixed.

        :return: data_path_splits as the data can be identified
            using dir_path and data_key
        """
        return self.data_path_splits  # type: ignore

    def get_num_images(self) -> int:
        """
        :return: int, number of images in this data set
        """
        return len(self.data_path_splits)  # type: ignore

    def close(self):
        """Close opened h5 file handles, do nothing if no file was opened."""
        if self.h5_files is None:
            return
        for f in self.h5_files.values():
            f.close()
||
158 |