Passed
Pull Request — main (#746)
by Yunguan
01:24
created

deepreg.config.parser.save()   A

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 11
rs 10
c 0
b 0
f 0
cc 2
nop 3
1
import os
2
from typing import Dict, List, Union
3
4
import yaml
5
6
from deepreg import log
7
from deepreg.config.v011 import parse_v011
8
9
# Module-level logger from deepreg's ``log`` helper; used below to report
# outdated configs in ``load_configs``.
logger = log.get(__name__)
10
11
12
def update_nested_dict(d: Dict, u: Dict) -> Dict:
    """
    Recursively merge dict ``u`` into dict ``d``.

    Nested dicts are merged key by key. Any other value in ``u`` overwrites the
    corresponding value in ``d``. ``d`` is modified in place and also returned.

    https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth

    :param d: dict to be overwritten in case of conflicts.
    :param u: dict to be merged into d.
    :return: the merged dict, i.e. ``d`` after the update.
    """
    for k, v in u.items():
        if isinstance(v, dict):
            existing = d.get(k)
            # A non-dict value already in d (e.g. a YAML ``null`` or a scalar)
            # cannot be merged into, so start from an empty dict instead of
            # recursing into it and crashing.
            if not isinstance(existing, dict):
                existing = {}
            d[k] = update_nested_dict(existing, v)
        else:
            d[k] = v
    return d
29
30
31
def load_configs(config_path: Union[str, List[str]]) -> Dict:
    """
    Load one or more YAML configs and merge them into a single nested dict.

    Files are merged in the given order, so values from later files override
    earlier ones. The merged config is then passed through
    ``config_sanity_check``. If that returns a different config, the updated
    version is saved next to the first config file as ``updated_<name>`` and
    an error is logged.

    :param config_path: list of paths or one path; ``~`` is expanded.
    :return: the loaded config, as returned by ``config_sanity_check``.
    """
    import copy  # only needed here, to protect the pre-check config

    paths = [config_path] if isinstance(config_path, str) else config_path
    # replace ~ with user home path
    paths = [os.path.expanduser(p) for p in paths]

    config: Dict = {}
    for path in paths:
        with open(path, encoding="utf-8") as file:
            # NOTE(review): FullLoader is kept for compatibility; config files
            # are assumed to be trusted local files.
            config_i = yaml.load(file, Loader=yaml.FullLoader)
        # An empty YAML file loads as None; treat it as an empty config
        # instead of crashing inside update_nested_dict.
        if config_i is None:
            config_i = {}
        config = update_nested_dict(d=config, u=config_i)

    # Check a deep copy so the comparison below stays meaningful even if the
    # check (or the back-compat parser it calls) mutates its input in place.
    loaded_config = config_sanity_check(copy.deepcopy(config))

    if loaded_config != config:
        # config got updated
        head, tail = os.path.split(paths[0])
        filename = "updated_" + tail
        save(config=loaded_config, out_dir=head, filename=filename)
        logger.error(
            "Used config is outdated. An updated version has been saved at %s.",
            os.path.join(head, filename),
        )

    return loaded_config
60
61
62
def save(config: dict, out_dir: str, filename: str = "config.yaml"):
    """
    Save the config into a yaml file.

    :param config: configuration to be output
    :param out_dir: directory of the output file
    :param filename: name of the output file, must end with ".yaml"
    :raises AssertionError: if filename does not end with ".yaml"
    """
    # Explicit raise instead of ``assert`` so the check also runs under
    # ``python -O``. AssertionError is kept for backward compatibility with
    # callers that already expect it.
    if not filename.endswith(".yaml"):
        raise AssertionError(f"filename must end with '.yaml', got '{filename}'")
    # Write-only access is enough; "w" creates or truncates the file.
    with open(os.path.join(out_dir, filename), "w", encoding="utf-8") as f:
        yaml.dump(config, f)
73
74
75
def config_sanity_check(config: dict) -> dict:
    """
    Check if the given config satisfies the requirements.

    The config is first converted by ``parse_v011`` for back compatibility.
    All checks are then applied to the converted config, which is the one
    returned to the caller.

    :param config: entire config.
    :return: the config returned by ``parse_v011``.
    :raises ValueError: if the conditional method is used with unlabeled data.
    """
    # back compatibility support
    config = parse_v011(config)

    # check data: read from the converted config so the checks below inspect
    # the same structure that is returned, not the pre-conversion one
    data_config = config["dataset"]

    # check model
    if config["train"]["method"] == "conditional":
        if data_config["labeled"] is False:  # unlabeled
            raise ValueError(
                "For conditional model, data have to be labeled, got unlabeled data."
            )

    return config
96