Passed — Pull Request — main (#662)

deepreg.config.parser.config_sanity_check() — rated D

Complexity

Conditions 13

Size

Total Lines 52
Code Lines 30

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 30
dl 0
loc 52
rs 4.2
c 0
b 0
f 0
cc 13
nop 1

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. Moreover, when a method is small, finding a good name for it is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.

Commonly applied refactorings include Extract Method.
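For instance, in `load_configs` from the file under review, the line guarded by the comment `# replace ~ with user home path` could be extracted into a helper named after that comment. A sketch; `expand_user_paths` is a hypothetical name, not part of the codebase:

```python
import os
from typing import List


def expand_user_paths(paths: List[str]) -> List[str]:
    """Replace ~ with the user home path in every entry.

    The old comment became the method name, so the call site
    in load_configs no longer needs a comment at all.
    """
    return [os.path.expanduser(p) for p in paths]
```

The call site then reads `config_path = expand_user_paths(config_path)`, with no comment needed.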

Complexity

Complex classes like deepreg.config.parser.config_sanity_check() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
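As an illustration only, the commented phases inside `config_sanity_check` below (`# check data`, `# check model`, the loss-weight loop) already mark cohesive components, and each could be extracted into a small named helper, reducing the conditions count of the remaining function. A sketch of the first two data checks; the helper names are hypothetical:

```python
def check_data_type(data_config: dict) -> None:
    """Hypothetical helper holding the data-type check from config_sanity_check."""
    if data_config["type"] not in ["paired", "unpaired", "grouped"]:
        raise ValueError(
            f"data type must be paired / unpaired / grouped, got {data_config['type']}."
        )


def check_data_format(data_config: dict) -> None:
    """Hypothetical helper holding the data-format check."""
    if data_config["format"] not in ["nifti", "h5"]:
        raise ValueError(
            f"data format must be nifti / h5, got {data_config['format']}."
        )
```

After such extractions, `config_sanity_check` becomes a short sequence of helper calls, and each helper's cyclomatic complexity stays in the low single digits.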

"""Module to parse, validate, and save DeepReg configuration files."""
import collections.abc
import logging
import os
from typing import List, Union

import yaml

from deepreg.config.v011 import parse_v011


def update_nested_dict(d: dict, u: dict) -> dict:
    """
    Merge two dicts.

    https://stackoverflow.com/questions/3232943/update-value-of-a-nested-dictionary-of-varying-depth

    :param d: dict to be overwritten in case of conflicts.
    :param u: dict to be merged into d.
    :return: the merged dict.
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            d[k] = update_nested_dict(d.get(k, {}), v)
        else:
            d[k] = v
    return d


def load_configs(config_path: Union[str, List[str]]) -> dict:
    """
    Load multiple configs and update the nested dictionary.

    :param config_path: list of paths or one path.
    :return: the loaded, sanity-checked config.
    """
    if isinstance(config_path, str):
        config_path = [config_path]
    # replace ~ with user home path
    config_path = list(map(os.path.expanduser, config_path))
    config = dict()
    for config_path_i in config_path:
        with open(config_path_i) as file:
            config_i = yaml.load(file, Loader=yaml.FullLoader)
        config = update_nested_dict(d=config, u=config_i)
    loaded_config = config_sanity_check(config)

    if loaded_config != config:
        # config got updated
        head, tail = os.path.split(config_path[0])
        filename = "updated_" + tail
        save(config=loaded_config, out_dir=head, filename=filename)
        logging.error(
            "Used config is outdated. An updated version has been saved at %s.",
            os.path.join(head, filename),
        )

    return loaded_config


def save(config: dict, out_dir: str, filename: str = "config.yaml"):
    """
    Save the config into a yaml file.

    :param config: configuration to be output
    :param out_dir: directory of the output file
    :param filename: name of the output file
    """
    assert filename.endswith(".yaml")
    with open(os.path.join(out_dir, filename), "w+") as f:
        f.write(yaml.dump(config))


def config_sanity_check(config: dict) -> dict:
    """
    Check if the given config satisfies the requirements.

    :param config: entire config.
    :raises ValueError: if the data type, format, or directories are invalid,
        or if a conditional model is configured with unlabeled data.
    :return: the parsed (possibly updated) config.
    """
    # check data
    data_config = config["dataset"]

    if data_config["type"] not in ["paired", "unpaired", "grouped"]:
        raise ValueError(
            f"data type must be paired / unpaired / grouped, got {data_config['type']}."
        )

    if data_config["format"] not in ["nifti", "h5"]:
        raise ValueError(
            f"data format must be nifti / h5, got {data_config['format']}."
        )

    assert "dir" in data_config
    for mode in ["train", "valid", "test"]:
        assert mode in data_config["dir"].keys()
        data_dir = data_config["dir"][mode]
        if data_dir is None:
            logging.warning("Data directory for %s is not defined.", mode)
        if not (isinstance(data_dir, (str, list)) or data_dir is None):
            raise ValueError(
                f"data_dir for mode {mode} must be string or list of strings, "
                f"got {data_dir}."
            )

    # back compatibility support
    config = parse_v011(config)

    # check model
    if config["train"]["method"] == "conditional":
        if data_config["labeled"] is False:  # unlabeled
            raise ValueError(
                "For conditional model, data have to be labeled, got unlabeled data."
            )

    # loss weights should be positive
    for name in ["image", "label", "regularization"]:
        loss_config = config["train"]["loss"][name]
        if not isinstance(loss_config, list):
            loss_config = [loss_config]

        for loss_i in loss_config:
            loss_weight = loss_i["weight"]
            if loss_weight <= 0:
                logging.warning(
                    "The %s loss weight %.2f is not positive.", name, loss_weight
                )

    return config
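To illustrate the merge semantics of `update_nested_dict`: on conflicts the value from `u` wins, while nested dicts are merged recursively rather than replaced wholesale. A self-contained sketch, reproducing the function's logic so the example runs on its own; the config keys used are made up for demonstration:

```python
import collections.abc


def update_nested_dict(d: dict, u: dict) -> dict:
    # Same logic as the function in the file above, copied here so the
    # example is self-contained.
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            d[k] = update_nested_dict(d.get(k, {}), v)
        else:
            d[k] = v
    return d


base = {"train": {"epochs": 10, "lr": 1e-3}, "seed": 0}
override = {"train": {"lr": 1e-4}}
merged = update_nested_dict(base, override)
# nested merge: "epochs" and "seed" survive, only "lr" is overridden
```

This recursive merge is what lets `load_configs` layer several YAML files, with later files overriding only the keys they mention.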