Passed
Branch main (46851d)
by Yunguan
02:04
created

deepreg.config.parser.config_sanity_check()   A

Complexity

Conditions 3

Size

Total Lines 21
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 8
dl 0
loc 21
rs 10
c 0
b 0
f 0
cc 3
nop 1
1
import os
2
from typing import Dict, List, Union
3
4
import yaml
5
6
from deepreg import log
7
from deepreg.config.v011 import parse_v011
8
9
logger = log.get(__name__)
10
11
12
def update_nested_dict(d: Dict, u: Dict) -> Dict:
13
    """
14
    Merge two dicts.
15
16
    https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth
17
18
    :param d: dict to be overwritten in case of conflicts.
19
    :param u: dict to be merged into d.
20
    :return:
21
    """
22
23
    for k, v in u.items():
24
        if isinstance(v, dict):
25
            d[k] = update_nested_dict(d.get(k, {}), v)
26
        else:
27
            d[k] = v
28
    return d
29
30
31
def load_configs(config_path: Union[str, List[str]]) -> Dict:
    """
    Load one or multiple yaml configs and merge them into one nested dict.

    Files are merged in the given order, so values from later files overwrite
    values from earlier ones. The merged config is then passed through
    config_sanity_check; if that returns a config different from the merged
    one, the returned version is saved next to the first config file as
    "updated_<name>.yaml" and an error is logged.

    :param config_path: list of paths or one path; "~" is expanded.
    :return: the loaded config, as returned by config_sanity_check
    """
    if isinstance(config_path, str):
        config_path = [config_path]
    # replace ~ with user home path
    config_path = [os.path.expanduser(x) for x in config_path]
    config: Dict = {}
    for config_path_i in config_path:
        with open(config_path_i) as file:
            # NOTE(review): FullLoader should not be used on untrusted files;
            # consider yaml.safe_load if configs can come from external sources.
            config_i = yaml.load(file, Loader=yaml.FullLoader)
        # an empty yaml file is loaded as None; treat it as an empty config
        if config_i is None:
            config_i = {}
        config = update_nested_dict(d=config, u=config_i)
    loaded_config = config_sanity_check(config)

    if loaded_config != config:
        # config got updated
        head, tail = os.path.split(config_path[0])
        filename = "updated_" + tail
        if not filename.endswith(".yaml"):
            # save() only writes .yaml files, e.g. "config.yml" -> "updated_config.yaml"
            filename = "updated_" + os.path.splitext(tail)[0] + ".yaml"
        save(config=loaded_config, out_dir=head, filename=filename)
        logger.error(
            "The provided configuration file is outdated. "
            "An updated version has been saved at %s.",
            os.path.join(head, filename),
        )

    return loaded_config
61
62
63
def save(config: dict, out_dir: str, filename: str = "config.yaml"):
    """
    Save the config into a yaml file.

    :param config: configuration to be output
    :param out_dir: directory of the output file; it must already exist
    :param filename: name of the output file, must end with ".yaml"
    :raises ValueError: if filename does not end with ".yaml"
    """
    # explicit check rather than assert, which is stripped under ``python -O``
    if not filename.endswith(".yaml"):
        raise ValueError(f"Config filename must end with .yaml, got {filename}.")
    with open(os.path.join(out_dir, filename), "w") as f:
        yaml.dump(config, f)
74
75
76
def config_sanity_check(config: dict) -> dict:
    """
    Check if the given config satisfies the requirements.

    The config is first passed through parse_v011 for backward compatibility,
    and the checks are then applied to the result of that parsing, which is
    also what gets returned.

    :param config: entire config.
    :return: the config after backward-compatibility parsing.
    :raises ValueError: if the conditional method is used with unlabeled data.
    """

    # back compatibility support
    config = parse_v011(config)

    # check data: read the dataset section after parsing, so the checks below
    # validate the same config that is returned
    data_config = config["dataset"]

    # check model
    if config["train"]["method"] == "conditional":
        if data_config["labeled"] is False:  # unlabeled
            raise ValueError(
                "For conditional model, data have to be labeled, got unlabeled data."
            )

    return config
97