Passed
Pull Request — main (#688)
by Yunguan
03:21
created

deepreg.config.parser.config_sanity_check()   C

Complexity

Conditions 9

Size

Total Lines 39
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 21
dl 0
loc 39
rs 6.6666
c 0
b 0
f 0
cc 9
nop 1
1
import logging
0 ignored issues
show
introduced by
Missing module docstring
Loading history...
2
import os
3
from typing import Dict, List, Union
4
5
import yaml
6
7
from deepreg.config.v011 import parse_v011
8
9
10
def update_nested_dict(d: Dict, u: Dict) -> Dict:
    """
    Recursively merge ``u`` into ``d``.

    https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth

    Values from ``u`` win on conflicts. When both sides hold a dict under
    the same key, the two dicts are merged key by key instead of the one
    in ``d`` being replaced. ``d`` is modified in place.

    :param d: dict to be overwritten in case of conflicts.
    :param u: dict to be merged into d.
    :return: ``d`` itself, after the merge.
    """

    for k, v in u.items():
        if isinstance(v, dict):
            existing = d.get(k)
            if not isinstance(existing, dict):
                # The key is missing from d, or holds a non-dict value
                # (scalar, None, list) that cannot be merged into;
                # start from an empty dict so that u's sub-dict wins.
                existing = {}
            d[k] = update_nested_dict(existing, v)
        else:
            d[k] = v
    return d
27
28
29
def load_configs(config_path: Union[str, List[str]]) -> Dict:
    """
    Load one or more YAML configs, merge them and validate the result.

    Files are merged in the given order with ``update_nested_dict``, so
    values from later files override values from earlier ones. The merged
    config is then passed to ``config_sanity_check``. If the config it
    returns differs from the merged one, the returned config is saved next
    to the first config file as ``updated_<name>`` and an error is logged.

    :param config_path: list of paths or one path.
    :return: the loaded config
    """
    if isinstance(config_path, str):
        config_path = [config_path]
    # replace ~ with user home path
    config_path = [os.path.expanduser(x) for x in config_path]
    config: Dict = {}
    for config_path_i in config_path:
        with open(config_path_i) as file:
            # NOTE(review): FullLoader can still construct some Python
            # objects; prefer yaml.safe_load if config files may come from
            # untrusted sources.
            config_i = yaml.load(file, Loader=yaml.FullLoader)
        if config_i is None:
            # An empty YAML file is parsed as None: there is nothing to merge.
            continue
        config = update_nested_dict(d=config, u=config_i)
    loaded_config = config_sanity_check(config)

    if loaded_config != config:
        # config got updated
        head, tail = os.path.split(config_path[0])
        filename = "updated_" + tail
        if not filename.endswith(".yaml"):
            # save() only accepts names ending in ".yaml"; normalise other
            # extensions (e.g. ".yml") so the updated config can be written.
            filename = os.path.splitext(filename)[0] + ".yaml"
        save(config=loaded_config, out_dir=head, filename=filename)
        # Pass the path as an argument so formatting is left to logging.
        logging.error(
            "Used config is outdated. An updated version has been saved at %s",
            os.path.join(head, filename),
        )

    return loaded_config
59
60
61
def save(config: dict, out_dir: str, filename: str = "config.yaml"):
    """
    Dump the config to a YAML file.

    :param config: configuration to be written out
    :param out_dir: directory in which the file is created
    :param filename: name of the output file, must end with ".yaml"
    """
    assert filename.endswith(".yaml")
    out_path = os.path.join(out_dir, filename)
    with open(out_path, "w+") as out_file:
        serialized = yaml.dump(config)
        out_file.write(serialized)
72
73
74
def config_sanity_check(config: dict) -> dict:
    """
    Check if the given config satisfies the requirements.

    The dataset section is validated first, then the config is passed
    through ``parse_v011`` for backward compatibility, and finally the
    combination of training method and data labels is checked.

    :param config: entire config.
    :return: the config as returned by ``parse_v011``.
    :raises ValueError: if the data type or data format is not supported,
        if a data directory is neither a string, a list nor None, or if
        the conditional method is used with unlabeled data.
    :raises AssertionError: if the dataset config has no "dir" entry or
        the entry lacks one of the train / valid / test keys.
    """

    # check data
    data_config = config["dataset"]

    data_type = data_config["type"]
    if data_type not in ["paired", "unpaired", "grouped"]:
        # Report the offending value, not the builtin ``type``.
        raise ValueError(
            f"data type must be paired / unpaired / grouped, got {data_type}."
        )

    data_format = data_config["format"]
    if data_format not in ["nifti", "h5"]:
        # Report the offending value, not the builtin ``format``.
        raise ValueError(f"data format must be nifti / h5, got {data_format}.")

    assert "dir" in data_config
    for mode in ["train", "valid", "test"]:
        assert mode in data_config["dir"].keys()
        data_dir = data_config["dir"][mode]
        if data_dir is None:
            # A missing directory is allowed; it is only reported.
            logging.warning("Data directory for %s is not defined.", mode)
        if not (isinstance(data_dir, (str, list)) or data_dir is None):
            raise ValueError(
                f"data_dir for mode {mode} must be string or list of strings, "
                f"got {data_dir}."
            )

    # back compatibility support
    config = parse_v011(config)

    # check model
    if config["train"]["method"] == "conditional":
        if data_config["labeled"] is False:  # unlabeled
            raise ValueError(
                "For conditional model, data have to be labeled, got unlabeled data."
            )

    return config
113