1
|
|
|
import os |
2
|
|
|
from copy import deepcopy |
3
|
|
|
from typing import Optional |
4
|
|
|
|
5
|
|
|
from deepreg.constant import KNOWN_DATA_SPLITS |
6
|
|
|
from deepreg.dataset.loader.interface import DataLoader |
7
|
|
|
from deepreg.registry import FILE_LOADER_CLASS, REGISTRY |
8
|
|
|
|
9
|
|
|
|
10
|
|
|
def get_data_loader(data_config: dict, split: str) -> Optional[DataLoader]:
    """
    Return the corresponding data loader.

    Can't be placed in the same file of loader interfaces as it causes import cycle.

    :param data_config: a dictionary containing configuration for data.
        Its "type" entry names the data loader to build. Every other
        top-level entry that is not a known split is forwarded to the
        data loader builder. The entry for ``split`` must provide "dir",
        "format" and "labeled".
    :param split: must be train/valid/test
    :return: DataLoader or None. Returns None if the split is missing or
        has no settings, or if its "dir" entry is missing or empty
        (None, "" or an empty list).
    :raises ValueError: if split is not a known split, or if any configured
        data directory is not an existing directory.
    """
    if split not in KNOWN_DATA_SPLITS:
        raise ValueError(f"split must be one of {KNOWN_DATA_SPLITS}, got {split}")

    split_config = data_config.get(split)
    if not split_config:
        # the split is absent, or declared without settings (None / empty dict)
        return None

    data_dir_paths = split_config.get("dir", None)
    if not data_dir_paths:
        # None, "" and empty sequences all mean "no data for this split"
        return None

    if isinstance(data_dir_paths, str):
        data_dir_paths = [data_dir_paths]
    # replace ~ with user home path
    data_dir_paths = [os.path.expanduser(path) for path in data_dir_paths]
    for data_dir_path in data_dir_paths:
        if not os.path.isdir(data_dir_path):
            raise ValueError(
                f"Data directory path {data_dir_path} for split {split}"
                f" is not a directory or does not exist"
            )

    # prepare data loader config: drop the per-split sections, then copy so
    # that popping "type" below does not mutate the caller's data_config
    data_loader_config = deepcopy(
        {k: v for k, v in data_config.items() if k not in KNOWN_DATA_SPLITS}
    )
    data_loader_config["name"] = data_loader_config.pop("type")

    default_args = dict(
        data_dir_paths=data_dir_paths,
        file_loader=REGISTRY.get(
            category=FILE_LOADER_CLASS, key=split_config["format"]
        ),
        labeled=split_config["labeled"],
        # training samples labels randomly with an unseeded RNG;
        # evaluation splits use all labels with a fixed seed of 0
        sample_label="sample" if split == "train" else "all",
        seed=None if split == "train" else 0,
    )
    data_loader: DataLoader = REGISTRY.build_data_loader(
        config=data_loader_config, default_args=default_args
    )
    return data_loader
60
|
|
|
|