Passed
Pull Request — main (#724)
created by unknown · 01:41

deepreg.train.main()   B

Complexity
    Conditions: 1

Size
    Total Lines: 84
    Code Lines: 55

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric   Value
eloc     55
dl       0
loc      84
rs       8.4727
c        0
b        0
f        0
cc       1
nop      1

How to fix: Long Method

Small methods make your code easier to understand, especially when combined with a good name. Moreover, if a method is small, finding a good name for it is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for its name.

Commonly applied refactorings include Extract Method, sketched below for this module.
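As a minimal sketch of Extract Method applied to the listing below, the commented "# set env variables" block at the top of train() could become its own well-named helper. The helper name set_env_vars is hypothetical and not part of DeepReg; the environment variables are the ones the module already sets.

import os


def set_env_vars(gpu: str, gpu_allow_growth: bool, num_cpus: int) -> None:
    """Configure GPU visibility, memory-growth behaviour, and CPU thread limits."""
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    if num_cpus > 0:
        # cap OpenMP and TensorFlow intra-/inter-op thread pools
        os.environ["OMP_NUM_THREADS"] = str(num_cpus)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_cpus)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_cpus)


# inside train(), the extracted block then reads as a single, self-describing call:
# set_env_vars(gpu=gpu, gpu_allow_growth=gpu_allow_growth, num_cpus=num_cpus)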

# coding=utf-8

"""
Module to train a network using init files and a CLI.
"""

import argparse
import os
from typing import Dict, List, Tuple, Union

import tensorflow as tf

import deepreg.config.parser as config_parser
import deepreg.model.optimizer as opt
from deepreg.callback import build_checkpoint_callback
from deepreg.registry import REGISTRY
from deepreg.util import build_dataset, build_log_dir


def build_config(
    config_path: Union[str, List[str]],
    log_dir: str,
    exp_name: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> Tuple[Dict, str, str]:
    """
    Function to initialise log directories, assert that the checkpointed
    model is of the right type, and parse the configuration for training.

    :param config_path: list of str, path to config file
    :param log_dir: path of the log directory
    :param exp_name: name of the experiment
    :param ckpt_path: path where model is stored.
    :param max_epochs: if max_epochs > 0, use it to overwrite the configuration
    :return: - config: a dictionary saving configuration
             - log_dir: the path of the directory to save logs
             - ckpt_path: the expanded checkpoint path
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # load config
    config = config_parser.load_configs(config_path)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # overwrite epochs and save_period if necessary
    if max_epochs > 0:
        config["train"]["epochs"] = max_epochs
        config["train"]["save_period"] = min(max_epochs, config["train"]["save_period"])

    # backup config
    config_parser.save(config=config, out_dir=log_dir)

    # batch_size in original config corresponds to batch_size per GPU
    gpus = tf.config.experimental.list_physical_devices("GPU")
    config["train"]["preprocess"]["batch_size"] *= max(len(gpus), 1)

    return config, log_dir, ckpt_path


def train(
    gpu: str,
    config_path: Union[str, List[str]],
    ckpt_path: str,
    num_cpus: int = -1,
    gpu_allow_growth: bool = True,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Function to train a model.

    :param gpu: which local gpu to use to train.
    :param config_path: path to configuration set up.
    :param ckpt_path: where to store training checkpoints.
    :param num_cpus: number of cpus to be used, -1 means not limited.
    :param gpu_allow_growth: whether to allow GPU memory to grow on demand instead of pre-allocating all GPU memory.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param max_epochs: if max_epochs > 0, will use it to overwrite the configuration.
    """
    # set env variables
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    if num_cpus > 0:
        # Maximum number of threads to use for OpenMP parallel regions.
        os.environ["OMP_NUM_THREADS"] = str(num_cpus)
        # Without setting below 2 environment variables, it didn't work for me. Thanks to @cjw85
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_cpus)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_cpus)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None
    data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="valid",
        training=False,
        repeat=True,
    )

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:
        strategy = tf.distribute.MirroredStrategy()  # pragma: no cover
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader_train.moving_image_shape,
                fixed_image_size=data_loader_train.fixed_image_shape,
                index_size=data_loader_train.num_indices,
                labeled=config["dataset"]["labeled"],
                batch_size=config["train"]["preprocess"]["batch_size"],
                config=config["train"],
                num_devices=num_devices,
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])

    # compile
    model.compile(optimizer=optimizer)
    model.plot_model(output_dir=log_dir)

    # build callbacks
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=config["train"]["save_period"],
        update_freq=config["train"].get("update_freq", "epoch"),
    )
    ckpt_callback, initial_epoch = build_checkpoint_callback(
        model=model,
        dataset=dataset_train,
        log_dir=log_dir,
        save_period=config["train"]["save_period"],
        ckpt_path=ckpt_path,
    )
    callbacks = [tensorboard_callback, ckpt_callback]

    # train
    # it's necessary to define the steps_per_epoch
    # and validation_steps to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=dataset_train,
        steps_per_epoch=steps_per_epoch_train,
        initial_epoch=initial_epoch,
        epochs=config["train"]["epochs"],
        validation_data=dataset_val,
        validation_steps=steps_per_epoch_val,
        callbacks=callbacks,
    )

    # close file loaders in data loaders after training
    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()


def main(args=None):
    """
    Entry point for train script.

    :param args: arguments
    """

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for training."
        '-g "" for using CPU'
        '-g "0" for using GPU 0'
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=True,
    )

    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory",
        default=False,
    )

    parser.add_argument(
        "--num_cpus",
        help="Number of CPUs to be used, -1 means unlimited.",
        type=int,
        default=-1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of the saved model checkpoint to load."
        "No need to provide if start training from scratch.",
        default="",
        type=str,
        required=False,
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name",
        "-l",
        help="Name of log directory."
        "The directory is under log root, e.g. logs/ by default."
        "If not provided, a timestamp based folder will be created.",
        default="",
        type=str,
    )

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="+",
        required=True,
    )

    parser.add_argument(
        "--max_epochs",
        help="The maximum number of epochs, -1 means following configuration.",
        type=int,
        default=-1,
    )

    args = parser.parse_args(args)
    train(
        gpu=args.gpu,
        config_path=args.config_path,
        num_cpus=args.num_cpus,
        gpu_allow_growth=args.gpu_allow_growth,
        ckpt_path=args.ckpt_path,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        max_epochs=args.max_epochs,
    )


if __name__ == "__main__":
    main()  # pragma: no cover
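For reference, the argparse-based CLI above can also be exercised programmatically by passing an argument list to main(), which is convenient in tests. A minimal sketch, assuming a YAML config exists at config/unpaired_labeled_ddf.yaml (the path is hypothetical) and training runs on CPU:

from deepreg.train import main

# "-g"/"--gpu" with an empty string selects CPU; "--max_epochs" overrides the config.
main(
    args=[
        "--gpu",
        "",
        "--config_path",
        "config/unpaired_labeled_ddf.yaml",  # hypothetical config path
        "--max_epochs",
        "2",
    ]
)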