Passed
Pull Request — main (#724) · by unknown · created 01:41

deepreg.train.train() — grade C

Complexity:  Conditions 6
Size:        Total Lines 119 · Code Lines 74
Duplication: Lines 0 · Ratio 0 %
Importance:  Changes 0
Metric  Value
eloc    74
dl      0
loc     119
rs      6.9175
c       0
b       0
f       0
cc      6
nop     8

How to fix

Long Method

Small methods make your code easier to understand, especially when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method; the comment is then a good starting point for naming it.

Commonly applied refactorings include:

- Extract Method (see the sketch below)
- Replace Temp with Query
- Decompose Conditional
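For instance, in the train() function shown below, the environment setup hides behind a "# set env variables" comment. A minimal Extract Method sketch, using a hypothetical helper name set_compute_environment (not part of the original code), turns that comment into a function:

import os


def set_compute_environment(gpu: str, gpu_allow_growth: bool, num_cpus: int) -> None:
    """Configure GPU visibility, memory growth, and CPU thread limits."""
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    if num_cpus > 0:
        # cap OpenMP and both TensorFlow intra-/inter-op thread pools
        os.environ["OMP_NUM_THREADS"] = str(num_cpus)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_cpus)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_cpus)

Inside train(), the whole block then collapses to a single call, set_compute_environment(gpu, gpu_allow_growth, num_cpus), and the comment has become the name.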

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as callers need more, or different, data.

There are several approaches to avoid long parameter lists:

- Introduce Parameter Object (see the sketch below)
- Preserve Whole Object
- Replace Parameter with Method Call
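As a sketch of Introduce Parameter Object: train() below takes eight parameters (nop 8 in the metrics above). A hypothetical TrainOptions dataclass, not part of the original code, could bundle them into one value:

from dataclasses import dataclass
from typing import List, Union


@dataclass
class TrainOptions:
    """Hypothetical parameter object bundling train()'s eight arguments."""

    gpu: str
    config_path: Union[str, List[str]]
    ckpt_path: str
    num_cpus: int = -1
    gpu_allow_growth: bool = True
    exp_name: str = ""
    log_dir: str = "logs"
    max_epochs: int = -1


def train(options: TrainOptions):
    ...  # body unchanged, reading options.gpu instead of gpu, and so on

main() would then build a single TrainOptions from the parsed arguments, so adding an option no longer touches every signature along the call chain.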

# coding=utf-8

"""
Module to train a network using config files and a CLI.
"""

import argparse
import os
from typing import Dict, List, Tuple, Union

import tensorflow as tf

import deepreg.config.parser as config_parser
import deepreg.model.optimizer as opt
from deepreg.callback import build_checkpoint_callback
from deepreg.registry import REGISTRY
from deepreg.util import build_dataset, build_log_dir


def build_config(
    config_path: Union[str, List[str]],
    log_dir: str,
    exp_name: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> Tuple[Dict, str, str]:
    """
    Initialise the log directory and parse the configuration for training.

    :param config_path: path, or list of paths, to the config file(s)
    :param log_dir: path of the log root directory
    :param exp_name: name of the experiment
    :param ckpt_path: path where the model checkpoint is stored
    :param max_epochs: if max_epochs > 0, use it to overwrite the configuration
    :return: - config: a dictionary holding the parsed configuration
             - log_dir: the path of the directory to save logs
             - ckpt_path: the checkpoint path with the user home expanded
    """
    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # load config
    config = config_parser.load_configs(config_path)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # overwrite epochs and save_period if necessary
    if max_epochs > 0:
        config["train"]["epochs"] = max_epochs
        config["train"]["save_period"] = min(max_epochs, config["train"]["save_period"])

    # backup config
    config_parser.save(config=config, out_dir=log_dir)

    # batch_size in original config corresponds to batch_size per GPU
    gpus = tf.config.experimental.list_physical_devices("GPU")
    config["train"]["preprocess"]["batch_size"] *= max(len(gpus), 1)

    return config, log_dir, ckpt_path

def train(
    gpu: str,
    config_path: Union[str, List[str]],
    ckpt_path: str,
    num_cpus: int = -1,
    gpu_allow_growth: bool = True,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Train a model.

    :param gpu: which local GPU(s) to use for training.
    :param config_path: path, or list of paths, to the configuration file(s).
    :param ckpt_path: path of a saved model checkpoint to load; empty to train from scratch.
    :param num_cpus: number of CPUs to be used, -1 means not limited.
    :param gpu_allow_growth: whether to let GPU memory grow on demand instead of reserving it all upfront.
    :param exp_name: experiment name.
    :param log_dir: path of the log directory.
    :param max_epochs: if max_epochs > 0, use it to overwrite the configuration.
    """
    # set env variables
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    if num_cpus > 0:
        # maximum number of threads to use for OpenMP parallel regions
        os.environ["OMP_NUM_THREADS"] = str(num_cpus)
        # the two TF thread settings below are also required for the
        # CPU limit to take effect (thanks to @cjw85)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_cpus)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_cpus)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None
    data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="valid",
        training=False,
        repeat=True,
    )

    # use a strategy to support multiple GPUs
    # the network is mirrored on each GPU so that we can use a larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only the model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:
        strategy = tf.distribute.MirroredStrategy()  # pragma: no cover
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader_train.moving_image_shape,
                fixed_image_size=data_loader_train.fixed_image_shape,
                index_size=data_loader_train.num_indices,
                labeled=config["dataset"]["labeled"],
                batch_size=config["train"]["preprocess"]["batch_size"],
                config=config["train"],
                num_devices=num_devices,
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])

    # compile
    model.compile(optimizer=optimizer)
    model.plot_model(output_dir=log_dir)

    # build callbacks
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=config["train"]["save_period"],
        update_freq=config["train"].get("update_freq", "epoch"),
    )
    ckpt_callback, initial_epoch = build_checkpoint_callback(
        model=model,
        dataset=dataset_train,
        log_dir=log_dir,
        save_period=config["train"]["save_period"],
        ckpt_path=ckpt_path,
    )
    callbacks = [tensorboard_callback, ckpt_callback]

    # train
    # it's necessary to define the steps_per_epoch
    # and validation_steps to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=dataset_train,
        steps_per_epoch=steps_per_epoch_train,
        initial_epoch=initial_epoch,
        epochs=config["train"]["epochs"],
        validation_data=dataset_val,
        validation_steps=steps_per_epoch_val,
        callbacks=callbacks,
    )

    # close file loaders in data loaders after training
    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()

def main(args=None):
    """
    Entry point for the train script.

    :param args: command line arguments
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for training. "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=True,
    )

    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory.",
        # store_true makes this a proper boolean flag; the original
        # default=False with no type meant any supplied value parsed as truthy
        action="store_true",
    )

    parser.add_argument(
        "--num_cpus",
        help="Number of CPUs to be used, -1 means unlimited.",
        type=int,
        default=-1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of the saved model checkpoint to load. "
        "No need to provide it if training starts from scratch.",
        default="",
        type=str,
        required=False,
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name",
        "-l",
        help="Name of the log directory. "
        "The directory is under the log root, e.g. logs/ by default. "
        "If not provided, a timestamp-based folder will be created.",
        default="",
        type=str,
    )

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="+",
        required=True,
    )

    parser.add_argument(
        "--max_epochs",
        help="The maximum number of epochs, -1 means following the configuration.",
        type=int,
        default=-1,
    )

    args = parser.parse_args(args)
    train(
        gpu=args.gpu,
        config_path=args.config_path,
        num_cpus=args.num_cpus,
        gpu_allow_growth=args.gpu_allow_growth,
        ckpt_path=args.ckpt_path,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        max_epochs=args.max_epochs,
    )


if __name__ == "__main__":
    main()  # pragma: no cover