deepreg.train.train()   C

Complexity

Conditions 8

Size

Total Lines 133
Code Lines 82

Duplication

Lines 13
Ratio 9.77 %

Importance

Changes 0

Metric                           Value
eloc (effective lines of code)   82
dl (duplicated lines)            13
loc (total lines of code)        133
rs                               5.7224
c                                0
b                                0
f                                0
cc (cyclomatic complexity)       8
nop (number of parameters)       8

How to fix: Long Method, Many Parameters

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments inside a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for its name.

Commonly applied refactorings include Extract Method and, for more involved cases, Replace Method with Method Object. A sketch of the first, applied to this function, follows.
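Applied to the train() function below, a minimal Extract Method sketch could pull the device-setup logic out into small, named helpers. The helper names _set_gpu_environment and _select_strategy are hypothetical, not part of DeepReg; their bodies simply regroup statements already present in the listing.

import os

import tensorflow as tf


def _set_gpu_environment(gpu: str, gpu_allow_growth: bool) -> None:
    # Expose the selected GPUs and configure TensorFlow memory growth.
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
            "true" if gpu_allow_growth else "false"
        )


def _select_strategy(batch_size: int) -> tf.distribute.Strategy:
    # Use MirroredStrategy when several GPUs are visible, otherwise the default strategy.
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices."
            )
        return tf.distribute.MirroredStrategy()
    return tf.distribute.get_strategy()

Each helper then replaces the corresponding commented block in train(), with the original comment becoming the helper's name or docstring.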

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to grow inconsistent as callers need more, or different, data.

There are several approaches to avoiding long parameter lists, such as Introduce Parameter Object and Preserve Whole Object; a sketch of the first follows.
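As an illustration only, the eight parameters of train() (shown in the listing below) could be grouped into a single parameter object. The TrainOptions dataclass is a hypothetical name, not part of DeepReg; its fields and defaults mirror the current signature, and the config path in the example call is a placeholder.

from dataclasses import dataclass
from typing import List, Union


@dataclass
class TrainOptions:
    # Groups the arguments that train() currently receives individually.
    gpu: str
    config_path: Union[str, List[str]]
    ckpt_path: str
    num_workers: int = 1
    gpu_allow_growth: bool = True
    exp_name: str = ""
    log_dir: str = "logs"
    max_epochs: int = -1


def train(options: TrainOptions) -> None:
    # The body would read options.gpu, options.config_path, etc.
    ...


# Example call site: one object instead of eight keyword arguments.
opts = TrainOptions(gpu="", config_path="demo_config.yaml", ckpt_path="")
train(opts)

This also keeps the CLI in main() simple: the parsed argparse namespace maps onto the dataclass in one place.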

# coding=utf-8

"""
Module to train a network using init files and a CLI.
"""

import argparse
import os
from typing import Dict, List, Tuple, Union

import tensorflow as tf

import deepreg.config.parser as config_parser
import deepreg.model.optimizer as opt
from deepreg import log
from deepreg.callback import build_checkpoint_callback
from deepreg.registry import REGISTRY
from deepreg.util import build_dataset, build_log_dir

logger = log.get(__name__)


def build_config(
    config_path: Union[str, List[str]],
    log_dir: str,
    exp_name: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> Tuple[Dict, str, str]:
    """
    Function to initialise log directories,
    assert that checkpointed model is the right
    type and to parse the configuration for training.

    :param config_path: list of str, path to config file
    :param log_dir: path of the log directory
    :param exp_name: name of the experiment
    :param ckpt_path: path where model is stored.
    :param max_epochs: if max_epochs > 0, use it to overwrite the configuration
    :return: - config: a dictionary saving configuration
             - exp_name: the path of directory to save logs
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # load config
    config = config_parser.load_configs(config_path)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # overwrite epochs and save_period if necessary
    if max_epochs > 0:
        config["train"]["epochs"] = max_epochs
        config["train"]["save_period"] = min(max_epochs, config["train"]["save_period"])

    # backup config
    config_parser.save(config=config, out_dir=log_dir)

    return config, log_dir, ckpt_path


def train(
    gpu: str,
    config_path: Union[str, List[str]],
    ckpt_path: str,
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Function to train a model.

    :param gpu: which local gpu to use to train.
    :param config_path: path to configuration set up.
    :param ckpt_path: where to store training checkpoints.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: whether to allocate whole GPU memory for training.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param max_epochs: if max_epochs > 0, will use it to overwrite the configuration.
    """
    # set env variables
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
            "true" if gpu_allow_growth else "false"
        )
    if num_workers <= 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the training. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None
    data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="valid",
        training=False,
        repeat=True,
    )

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    batch_size = config["train"]["preprocess"]["batch_size"]
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader_train.moving_image_shape,
                fixed_image_size=data_loader_train.fixed_image_shape,
                index_size=data_loader_train.num_indices,
                labeled=config["dataset"]["train"]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # build callbacks
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=config["train"]["save_period"],
        update_freq=config["train"].get("update_freq", "epoch"),
    )
    ckpt_callback, initial_epoch = build_checkpoint_callback(
        model=model,
        dataset=dataset_train,
        log_dir=log_dir,
        save_period=config["train"]["save_period"],
        ckpt_path=ckpt_path,
    )
    callbacks = [tensorboard_callback, ckpt_callback]

    # train
    # it's necessary to define the steps_per_epoch
    # and validation_steps to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=dataset_train,
        steps_per_epoch=steps_per_epoch_train,
        initial_epoch=initial_epoch,
        epochs=config["train"]["epochs"],
        validation_data=dataset_val,
        validation_steps=steps_per_epoch_val,
        callbacks=callbacks,
    )

    # close file loaders in data loaders after training
    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()


def main(args=None):
    """
    Entry point for train script.

    :param args: arguments
    """

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for training."
        "-g for using GPU remotely"
        '-g "" for using CPU'
        '-g "0" for using GPU 0'
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=False,
    )

    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory",
        default=False,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of the saved model checkpoint to load."
        "No need to provide if start training from scratch.",
        default="",
        type=str,
        required=False,
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name",
        "-l",
        help="Name of log directory."
        "The directory is under log root, e.g. logs/ by default."
        "If not provided, a timestamp based folder will be created.",
        default="",
        type=str,
    )

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="+",
        required=True,
    )

    parser.add_argument(
        "--max_epochs",
        help="The maximum number of epochs, -1 means following configuration.",
        type=int,
        default=-1,
    )

    args = parser.parse_args(args)
    train(
        gpu=args.gpu,
        config_path=args.config_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        ckpt_path=args.ckpt_path,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        max_epochs=args.max_epochs,
    )


if __name__ == "__main__":
    main()  # pragma: no cover
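For reference, a minimal way to exercise this entry point from Python rather than from the command line (the config path below is a placeholder):

from deepreg.train import main

# Arguments are forwarded to argparse exactly as CLI flags would be.
main(
    args=[
        "--gpu",
        "",
        "--config_path",
        "demo_config.yaml",
    ]
)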