Completed: push to main (d3edf2...e8f714) by Yunguan, 21s, queued 14s

deepreg.dataset.load.get_data_loader() (rating: C)

Complexity: 10 conditions
Size: 50 total lines, 29 code lines
Duplication: 0 lines (0% ratio)
Importance: 0 changes

Metric   Value
eloc     29
dl       0
loc      50
rs       5.9999
c        0
b        0
f        0
cc       10
nop      2

How to fix: Complexity

Complex functions like deepreg.dataset.load.get_data_loader() often do several different things. To break such a function down, identify a cohesive group of statements within it, for example statements that work on the same local variables or that share a single purpose such as validating input or assembling a configuration.

Once you have identified the statements that belong together, you can apply the Extract Function (Extract Method) refactoring and move each group into its own well-named helper. The same idea applies to classes, where the corresponding refactorings are Extract Class and, if the component makes sense as a subclass, Extract Subclass.

```python
import os
from copy import deepcopy
from typing import Optional

from deepreg.constant import KNOWN_DATA_SPLITS
from deepreg.dataset.loader.interface import DataLoader
from deepreg.registry import FILE_LOADER_CLASS, REGISTRY


def get_data_loader(data_config: dict, split: str) -> Optional[DataLoader]:
    """
    Return the corresponding data loader.

    Cannot be placed in the same file as the loader interfaces,
    because doing so would cause an import cycle.

    :param data_config: a dictionary containing the data configuration
    :param split: must be one of train/valid/test
    :return: a DataLoader, or None if the split is missing from the config
        or its "dir" entry is empty
    """
    if split not in KNOWN_DATA_SPLITS:
        raise ValueError(f"split must be one of {KNOWN_DATA_SPLITS}, got {split}")

    if split not in data_config:
        return None
    data_dir_paths = data_config[split].get("dir", None)
    if data_dir_paths is None or data_dir_paths == "":
        return None

    # accept a single path as well as a list of paths
    if isinstance(data_dir_paths, str):
        data_dir_paths = [data_dir_paths]
    # replace ~ with user home path
    data_dir_paths = list(map(os.path.expanduser, data_dir_paths))
    for data_dir_path in data_dir_paths:
        if not os.path.isdir(data_dir_path):
            raise ValueError(
                f"Data directory path {data_dir_path} for split {split}"
                " is not a directory or does not exist"
            )

    # prepare data loader config: drop the per-split sub-configs
    # and rename "type" to "name" for the registry
    data_loader_config = deepcopy(data_config)
    data_loader_config = {
        k: v for k, v in data_loader_config.items() if k not in KNOWN_DATA_SPLITS
    }
    data_loader_config["name"] = data_loader_config.pop("type")

    # split-dependent defaults: train uses sample_label="sample" and no seed,
    # valid/test use sample_label="all" and seed 0
    default_args = dict(
        data_dir_paths=data_dir_paths,
        file_loader=REGISTRY.get(
            category=FILE_LOADER_CLASS, key=data_config[split]["format"]
        ),
        labeled=data_config[split]["labeled"],
        sample_label="sample" if split == "train" else "all",
        seed=None if split == "train" else 0,
    )
    data_loader: DataLoader = REGISTRY.build_data_loader(
        config=data_loader_config, default_args=default_args
    )
    return data_loader
```
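To make the config-preparation step concrete, the sketch below replays it on a hypothetical data_config. The keys "type", "dir", "format" and "labeled" are the ones get_data_loader() reads; the values ("paired", "nifti", the paths) and the extra "moving_image_shape" entry are illustrative assumptions. The function name prepare_loader_config is hypothetical, and the data_dir_paths and file_loader defaults are omitted because they need the file system checks and DeepReg's REGISTRY.

```python
from copy import deepcopy

# Assumption: stand-in for deepreg.constant.KNOWN_DATA_SPLITS.
KNOWN_DATA_SPLITS = ["train", "valid", "test"]


def prepare_loader_config(data_config: dict, split: str):
    """Hypothetical replay of get_data_loader()'s config preparation."""
    # keep shared keys such as "type"; drop the per-split sub-configs
    loader_config = {
        k: v for k, v in deepcopy(data_config).items() if k not in KNOWN_DATA_SPLITS
    }
    # the registry receives the loader key under "name" instead of "type"
    loader_config["name"] = loader_config.pop("type")
    # split-dependent defaults, copied from get_data_loader()
    default_args = dict(
        labeled=data_config[split]["labeled"],
        sample_label="sample" if split == "train" else "all",
        seed=None if split == "train" else 0,
    )
    return loader_config, default_args


data_config = {
    "type": "paired",  # hypothetical loader name
    "moving_image_shape": [16, 16, 16],  # hypothetical loader-specific key
    "train": {"dir": "~/data/train", "format": "nifti", "labeled": True},
    "test": {"dir": "~/data/test", "format": "nifti", "labeled": False},
}
loader_config, default_args = prepare_loader_config(data_config, "test")
print(loader_config)  # {'moving_image_shape': [16, 16, 16], 'name': 'paired'}
print(default_args)  # {'labeled': False, 'sample_label': 'all', 'seed': 0}
```

Because the dictionary comprehension already builds a new top-level dictionary, the deepcopy in the original mainly ensures that nested values such as lists are not shared with the caller's configuration.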