Passed
Pull Request — main (#756)
by
unknown
01:50
created

deepreg.config.parser.save()   A

Complexity

Conditions 2

Size

Total Lines 11
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 4
dl 0
loc 11
rs 10
c 0
b 0
f 0
cc 2
nop 3
1
import os
2
from typing import Dict, List, Union
3
4
import wandb
5
import yaml
6
7
from deepreg import log
8
from deepreg.config.v011 import parse_v011
9
10
# module-level logger obtained from deepreg's ``log`` helper
logger = log.get(__name__)
11
12
13
def update_nested_dict(d: Dict, u: Dict) -> Dict:
    """
    Recursively merge ``u`` into ``d`` in place.

    Nested dictionaries are merged key by key; any other value in ``u``
    overwrites the corresponding value in ``d``. If ``u`` holds a dict under
    a key whose value in ``d`` is not a dict (or is missing), that value in
    ``d`` is replaced by a merged copy of ``u``'s dict.

    https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth

    :param d: dict to be overwritten in case of conflicts; modified in place.
    :param u: dict to be merged into d; it is not modified.
    :return: d, after the merge.
    """
    for k, v in u.items():
        if isinstance(v, dict):
            existing = d.get(k)
            # a non-dict value in d cannot absorb a nested dict, so u wins
            if not isinstance(existing, dict):
                existing = {}
            d[k] = update_nested_dict(existing, v)
        else:
            d[k] = v
    return d
30
31
32
def load_configs(config_path: Union[str, List[str]]) -> Dict:
    """
    Load one or more yaml configs, merge them and sanity-check the result.

    Files are merged in the given order with update_nested_dict, so values
    from later files override those from earlier ones. An empty yaml file
    contributes nothing. If config_sanity_check returns a config different
    from the merged one, the returned config is saved next to the first
    file as ``updated_<filename>`` and an error is logged.

    :param config_path: list of paths or one path; ``~`` is expanded.
    :return: the loaded config, as returned by config_sanity_check.
    :raises ValueError: propagated from config_sanity_check.
    """
    if isinstance(config_path, str):
        config_path = [config_path]
    # replace ~ with user home path
    config_path = [os.path.expanduser(x) for x in config_path]
    config: Dict = {}
    for config_path_i in config_path:
        with open(config_path_i) as file:
            # NOTE(review): FullLoader accepts some python-specific tags that
            # SafeLoader rejects; consider yaml.safe_load if config files may
            # come from untrusted sources.
            config_i = yaml.load(file, Loader=yaml.FullLoader)
        # yaml returns None for an empty document; treat it as an empty config
        config = update_nested_dict(d=config, u=config_i or {})
    loaded_config = config_sanity_check(config)

    if loaded_config != config:
        # config got updated during the sanity check, i.e. it was outdated
        head, tail = os.path.split(config_path[0])
        filename = "updated_" + tail
        save(config=loaded_config, out_dir=head, filename=filename)
        logger.error(
            "The provided configuration file is outdated. "
            "An updated version has been saved at %s.",
            os.path.join(head, filename),
        )

    return loaded_config
62
63
64
def save(config: dict, out_dir: str, filename: str = "config.yaml"):
    """
    Write a configuration dictionary to ``out_dir/filename`` as yaml.

    The file is created or truncated. ``filename`` must end with ``.yaml``;
    this is enforced with an ``assert``, so the check is skipped under -O.

    :param config: configuration to be written
    :param out_dir: directory of the output file
    :param filename: name of the output file
    """
    assert filename.endswith(".yaml")
    out_path = os.path.join(out_dir, filename)
    with open(out_path, "w+") as out_file:
        # let yaml stream directly into the file instead of building a string
        yaml.dump(config, out_file)
75
76
77
def config_sanity_check(config: dict) -> dict:
    """
    Check if the given config satisfies the requirements.

    The config is first passed through parse_v011 for backward
    compatibility; the checks below are then applied to that parsed config,
    which is also what gets returned.

    :param config: entire config.
    :return: the config after backward-compatibility parsing.
    :raises ValueError: if the train method is "conditional" but the dataset
        is marked as unlabeled.
    """
    # back compatibility support; must run before the checks so that they
    # inspect the same config that is returned to the caller
    config = parse_v011(config)

    data_config = config["dataset"]

    # check model: a conditional model needs labeled data
    if config["train"]["method"] == "conditional":
        if data_config["labeled"] is False:  # unlabeled
            raise ValueError(
                "For conditional model, data have to be labeled, got unlabeled data."
            )

    return config
98
99
100
def has_wandb_callback(config: dict) -> bool:
    """
    Check whether a config requests Weights & Biases (W&B) logging.

    :param config: config dictionary with parameters for run.
    :return: True if the config has a top-level ``wandb`` key, else False.
    """
    return "wandb" in config
110
111
112
def instantiate_wandb_run(config: dict):
    """
    Start a W&B run from the ``wandb`` section of a config.

    If ``config["wandb"]`` has no ``init`` entry, an error is logged and
    wandb.init is called without arguments; otherwise the ``init`` entry is
    passed to wandb.init as keyword arguments.

    :param config: config dictionary with a ``wandb`` key.
    :return: None.
    """
    wandb_config = config["wandb"]
    if "init" not in wandb_config:
        logger.error("No init field in config. Creating empty init.")
        wandb.init()
    else:
        # an empty ``init:`` entry in yaml is loaded as None; treat it as {}
        wandb.init(**(wandb_config["init"] or {}))
126
127
128
def instantiate_wandb_callback(config: dict):
    """
    Build a W&B Keras callback from the ``wandb`` section of a config.

    If ``config["wandb"]`` has no ``callback`` entry, an error is logged and a
    callback with default arguments is created; otherwise the ``callback``
    entry is passed to wandb.keras.WandbCallback as keyword arguments.

    :param config: config dictionary with a ``wandb`` key.
    :return: tf.keras.callback, see W&B docs for more info.
    """
    wandb_config = config["wandb"]
    if "callback" not in wandb_config:
        logger.error("No callback field in config. Creating empty callback.")
        return wandb.keras.WandbCallback()
    # an empty ``callback:`` entry in yaml is loaded as None; treat it as {}
    return wandb.keras.WandbCallback(**(wandb_config["callback"] or {}))
146