Completed
Push — main ( 183d7f...45ab67 )
by Yunguan
19s queued 13s
created

deepreg.config.parser.config_sanity_check()   C

Complexity

Conditions 9

Size

Total Lines 39
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 21
dl 0
loc 39
rs 6.6666
c 0
b 0
f 0
cc 9
nop 1
1
import logging
2
import os
3
from typing import Dict, List, Union
4
5
import yaml
6
7
from deepreg.config.v011 import parse_v011
8
9
10
def update_nested_dict(d: Dict, u: Dict) -> Dict:
    """
    Recursively merge dict ``u`` into dict ``d``.

    Nested dicts are merged key by key; any other value in ``u`` replaces the
    corresponding value in ``d``. ``d`` is modified in place and also returned.

    https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth

    :param d: dict to be overwritten in case of conflicts.
    :param u: dict to be merged into d.
    :return: the merged dict, i.e. ``d`` after the update.
    """
    for k, v in u.items():
        if isinstance(v, dict):
            base = d.get(k)
            # an existing non-dict value (e.g. None coming from an empty YAML
            # key, or a scalar) cannot be merged into; u wins on conflicts,
            # so start from a fresh dict instead of crashing
            if not isinstance(base, dict):
                base = {}
            d[k] = update_nested_dict(base, v)
        else:
            d[k] = v
    return d
27
28
29
def load_configs(config_path: Union[str, List[str]]) -> Dict:
    """
    Load one or more YAML configs, merge them and run the sanity check.

    Files are merged in the given order with ``update_nested_dict``, so later
    files override earlier ones on conflicting keys. If ``config_sanity_check``
    returns a config different from the merged one, the returned version is
    saved next to the first config file with an ``updated_`` prefix and an
    error is logged.

    :param config_path: list of paths or one path; ``~`` is expanded.
    :return: the loaded config, as returned by ``config_sanity_check``.
    :raises ValueError: if a config file does not contain a mapping at its
        top level.
    """
    if isinstance(config_path, str):
        config_path = [config_path]
    # replace ~ with user home path
    config_path = [os.path.expanduser(x) for x in config_path]
    config: Dict = {}
    for config_path_i in config_path:
        with open(config_path_i) as file:
            # NOTE(review): FullLoader is less restrictive than SafeLoader;
            # this assumes config files are trusted. Consider yaml.safe_load
            # if configs may ever come from untrusted sources.
            config_i = yaml.load(file, Loader=yaml.FullLoader)
        if config_i is None:
            # an empty YAML document parses to None: nothing to merge
            continue
        if not isinstance(config_i, dict):
            raise ValueError(
                f"Config file {config_path_i} must contain a mapping at the "
                f"top level, got {type(config_i).__name__}."
            )
        config = update_nested_dict(d=config, u=config_i)
    loaded_config = config_sanity_check(config)

    if loaded_config != config:
        # config got updated
        head, tail = os.path.split(config_path[0])
        filename = "updated_" + tail
        save(config=loaded_config, out_dir=head, filename=filename)
        logging.error(
            f"Used config is outdated. "
            f"An updated version has been saved at "
            f"{os.path.join(head, filename)}"
        )

    return loaded_config
59
60
61
def save(config: dict, out_dir: str, filename: str = "config.yaml"):
    """
    Write a configuration dict to ``out_dir/filename`` in YAML format.

    An existing file at that path is truncated and overwritten.

    :param config: configuration to be written out
    :param out_dir: directory in which the file is created
    :param filename: name of the output file, must end with ".yaml"
    """
    assert filename.endswith(".yaml")
    out_path = os.path.join(out_dir, filename)
    with open(out_path, "w+") as yaml_file:
        # passing the stream lets yaml write the same text it would return
        yaml.dump(config, yaml_file)
72
73
74
def config_sanity_check(config: dict) -> dict:
    """
    Check if the given config satisfies the requirements.

    The ``dataset`` section is validated first. The config is then passed
    through ``parse_v011`` for backward compatibility. Finally, the training
    method is checked against the dataset's ``labeled`` flag.

    :param config: entire config.
    :return: the config returned by ``parse_v011``, which may differ from the
        input.
    :raises ValueError: if the data type or format is unsupported, a data
        directory has an invalid type, or a conditional model is requested
        with unlabeled data.
    """

    # check data
    data_config = config["dataset"]

    data_type = data_config["type"]
    if data_type not in ["paired", "unpaired", "grouped"]:
        raise ValueError(
            "data type must be paired / unpaired / grouped, "
            f"got {data_type}."
        )

    data_format = data_config["format"]
    if data_format not in ["nifti", "h5"]:
        raise ValueError(f"data format must be nifti / h5, got {data_format}.")

    assert "dir" in data_config, "dataset config must define 'dir'."
    for mode in ["train", "valid", "test"]:
        assert mode in data_config["dir"], f"dataset dir must define '{mode}'."
        data_dir = data_config["dir"][mode]
        if data_dir is None:
            logging.warning(f"Data directory for {mode} is not defined.")
        elif not isinstance(data_dir, (str, list)):
            raise ValueError(
                f"data_dir for mode {mode} must be string or list of strings, "
                f"got {data_dir}."
            )

    # back compatibility support
    config = parse_v011(config)

    # check model
    # NOTE(review): data_config is the dataset section read *before*
    # parse_v011; confirm parse_v011 does not change the 'labeled' flag.
    if config["train"]["method"] == "conditional":
        if data_config["labeled"] is False:  # unlabeled
            raise ValueError(
                "For conditional model, data have to be labeled, got unlabeled data."
            )

    return config
113