check_difference_between_two_lists()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 11
rs 10
c 0
b 0
f 0
cc 2
nop 3
1
"""
2
Module for IO of files in relation to
3
data loading.
4
"""
5
import glob
6
import itertools as it
7
import os
8
import random
9
from typing import List, Tuple, Union
10
11
import h5py
12
13
14
def get_h5_sorted_keys(filename: str) -> List[str]:
    """
    Read the top-level keys of an h5 file and return them sorted.

    The file is opened read-only and closed again before returning.

    :param filename: path of the h5 file.
    :return: top-level keys of the h5 file in ascending order.
    """
    with h5py.File(filename, "r") as h5_file:
        # materialise the keys while the file is still open
        keys = list(h5_file.keys())
    keys.sort()
    return keys
22
23
24
def get_sorted_file_paths_in_dir_with_suffix(
    dir_path: str, suffix: Union[str, List[str]]
) -> List[Tuple[str, ...]]:
    """
    Return the relative paths of all files under the given directory
    whose names end with one of the given suffixes.

    The directory is searched recursively. Matching is done on names only
    via glob, no check of the file type is performed.

    :param dir_path: path of the directory, it is taken literally,
        i.e. glob special characters such as [ ] * ? inside it
        are not interpreted as wildcards
    :param suffix: suffix or list of suffixes of file names
        like h5, nii.gz, nii, should not start with .
    :return: sorted list of relative file path, each element is
        (file_path, suffix) assuming the full path of the file is
        dir_path/file_path.suffix
    """
    suffixes = [suffix] if isinstance(suffix, str) else suffix
    # escape the directory so that a name like "data[1]" is matched
    # literally instead of being parsed as a character class by glob
    search_root = glob.escape(dir_path)
    paths = []
    for suffix_i in suffixes:
        extension = "." + suffix_i
        # full_path is dir_path/file_path.suffix
        full_paths = glob.glob(
            os.path.join(search_root, "**", "*" + extension), recursive=True
        )
        for full_path in full_paths:
            rel_path = os.path.relpath(path=full_path, start=dir_path)
            # drop the trailing ".suffix" to keep file_path only
            paths.append((rel_path[: -len(extension)], suffix_i))
    return sorted(paths)
49
50
51
def check_difference_between_two_lists(list1: list, list2: list, name: str) -> None:
    """
    Raise error if two lists are not identical

    The lists are compared element by element. When the lengths differ,
    every position beyond the end of the shorter list is reported as a
    difference, with None standing in for the missing element.

    :param list1: list
    :param list2: list
    :param name: name to be printed in case of difference
    :raises ValueError: if the lists differ in length or in any element
    """
    common_length = min(len(list1), len(list2))
    # zip_longest pads the shorter list with None, so a genuine trailing
    # None in the longer list would compare equal to the padding;
    # positions past the common length are therefore always a difference
    diff = [
        (x, y)
        for index, (x, y) in enumerate(it.zip_longest(list1, list2))
        if index >= common_length or x != y
    ]
    if diff:
        raise ValueError(f"{name} are not identical\n" f"difference are {diff}\n")
62
63
64
def get_label_indices(num_labels: int, sample_label: str) -> list:
    """
    Build the list of label indices to use, according to a sampling policy

    :param num_labels: int number of labels
    :param sample_label: sampling policy, "sample" picks one random
        label index, "all" keeps every label index
    :return: list of label indices selected by the policy.
    :raises ValueError: if the sampling policy is unknown
    """
    if sample_label == "sample":
        # a single index drawn from [0, num_labels)
        chosen = random.randrange(num_labels)
        return [chosen]
    if sample_label == "all":
        # every index from 0 to num_labels - 1
        return [index for index in range(num_labels)]
    raise ValueError("Unknown label sampling policy %s" % sample_label)
78