|
1
|
|
|
""" |
|
2
|
|
|
Module for IO of files in relation to |
|
3
|
|
|
data loading. |
|
4
|
|
|
""" |
|
5
|
|
|
import glob |
|
6
|
|
|
import itertools as it |
|
7
|
|
|
import os |
|
8
|
|
|
import random |
|
9
|
|
|
from typing import List, Tuple, Union |
|
10
|
|
|
|
|
11
|
|
|
import h5py |
|
12
|
|
|
|
|
13
|
|
|
|
|
14
|
|
|
def get_h5_sorted_keys(filename: str) -> List[str]:
    """
    Open an h5 file read-only and return its keys in sorted order.

    :param filename: path of the h5 file.
    :return: keys returned by ``h5py.File.keys()``, sorted ascending.
    """
    # materialise the keys while the file is still open; sorting can happen after
    with h5py.File(filename, "r") as h5_handle:
        keys = list(h5_handle.keys())
    keys.sort()
    return keys
|
22
|
|
|
|
|
23
|
|
|
|
|
24
|
|
|
def get_sorted_file_paths_in_dir_with_suffix(
    dir_path: str, suffix: Union[str, List[str]]
) -> List[Tuple[str, ...]]:
    """
    Return the path of all files under the given directory with the given suffix(es).

    The search is recursive. Only regular files are returned; directories whose
    names happen to end with the suffix are skipped. As with ``glob``, names
    starting with "." are not matched by the wildcards.

    :param dir_path: path of the directory
    :param suffix: suffix of file names like h5, nii.gz, nii, should not start with .
        Either a single suffix or a list of suffixes.
    :return: sorted list of relative file paths, each element is (file_path, suffix)
        assuming the full path of the file is dir_path/file_path.suffix
    """
    if isinstance(suffix, str):
        suffix = [suffix]
    # escape dir_path so characters such as "[", "*" or "?" in directory names
    # are matched literally instead of being interpreted as glob wildcards
    escaped_dir = glob.escape(dir_path)
    paths = []
    for suffix_i in suffix:
        # full_path is dir_path/file_path.suffix
        pattern = os.path.join(escaped_dir, "**", "*." + glob.escape(suffix_i))
        strip_len = len(suffix_i) + 1  # length of ".suffix" to cut off
        paths += [
            (os.path.relpath(path=p, start=dir_path)[:-strip_len], suffix_i)
            for p in glob.glob(pattern, recursive=True)
            if os.path.isfile(p)  # skip directories named like "x.suffix"
        ]
    return sorted(paths)
|
49
|
|
|
|
|
50
|
|
|
|
|
51
|
|
|
def check_difference_between_two_lists(list1: list, list2: list, name: str):
    """
    Raise error if two lists are not identical.

    Elements are compared pairwise by position. If one list is longer than the
    other, each surplus element is reported paired with None.

    :param list1: list
    :param list2: list
    :param name: name to be printed in case of difference
    :raises ValueError: if the lists differ in length or in any element
    """
    # pad the shorter list with a private sentinel rather than zip_longest's
    # default None, so trailing None elements in the longer list are still
    # detected as a length mismatch
    missing = object()
    diff = [
        (None if x is missing else x, None if y is missing else y)
        for x, y in it.zip_longest(list1, list2, fillvalue=missing)
        if x is missing or y is missing or x != y
    ]
    if diff:
        raise ValueError(f"{name} are not identical\n" f"difference are {diff}\n")
|
62
|
|
|
|
|
63
|
|
|
|
|
64
|
|
|
def get_label_indices(num_labels: int, sample_label: str) -> list:
    """
    Select which label indices to use under a given sampling policy.

    :param num_labels: number of available labels
    :param sample_label: sampling policy, either "sample" or "all"
    :return: ``[i]`` for one uniformly random index ``i`` when the policy is
        "sample"; ``[0, ..., num_labels - 1]`` when it is "all".
    :raises ValueError: if the policy is not recognised.
    """
    if sample_label == "all":
        # every label index, ascending
        return [*range(num_labels)]
    if sample_label == "sample":
        # a single label index drawn at random
        return [random.randrange(num_labels)]
    raise ValueError("Unknown label sampling policy %s" % sample_label)
|
78
|
|
|
|