Completed
Push to main (bbe77b...232bc8) by Yunguan
22s (queued 13s)

deepreg.train.train() (rated C)

Complexity
    Conditions: 6

Size
    Total Lines: 114
    Code Lines: 72

Duplication
    Duplicated Lines: 0
    Duplication Ratio: 0%

Importance
    Changes: 0
Metric                            Value
eloc (effective lines of code)       72
dl (duplicated lines)                 0
loc (total lines)                   114
rs                                6.983
c                                     0
b                                     0
f                                     0
cc (cyclomatic complexity)            6
nop (number of parameters)            7
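
For reference, comparable raw counts can be reproduced locally. Below is a minimal sketch assuming the radon package is installed (the analyzer behind this report is not named here, so its exact metric definitions may differ slightly):

# Sketch: recompute size and complexity figures for deepreg/train.py.
# Requires `pip install radon`; counts may differ from the report above.
from radon.complexity import cc_visit
from radon.raw import analyze

with open("deepreg/train.py") as f:
    source = f.read()

raw = analyze(source)  # namedtuple with loc, lloc, sloc, comments, blank, ...
print(f"total lines: {raw.loc}")    # compare against loc = 114
print(f"source lines: {raw.sloc}")  # compare against eloc = 72

for block in cc_visit(source):      # per-function cyclomatic complexity
    print(f"{block.name}: cc={block.complexity}")  # compare against cc = 6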

How to fix: Long Method

Small methods make your code easier to understand, especially when combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign that you should extract the commented part into a new method, using the comment as a starting point for naming it.

The most commonly applied refactoring is Extract Method, sketched below.
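
As an illustration, here is a hedged Extract Method sketch applied to the train() function reported above: the device-strategy selection is lifted into a helper so that train() shrinks and the block gains a name. The helper name build_strategy is an assumption for illustration, not part of DeepReg's API.

import tensorflow as tf


def build_strategy(batch_size: int) -> tf.distribute.Strategy:
    """Hypothetical helper extracted from train().

    Mirrors the model across GPUs when more than one is visible,
    after checking that the global batch size divides evenly.
    """
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} cannot be divided evenly "
                f"by the number of devices."
            )
        return tf.distribute.MirroredStrategy()
    return tf.distribute.get_strategy()


# Inside train(), the original ~10-line block then becomes one named call:
# strategy = build_strategy(batch_size=config["train"]["preprocess"]["batch_size"])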

# coding=utf-8

"""
Module to train a network using configuration files and a CLI.
"""

import argparse
import os
from typing import Dict, List, Tuple, Union

import tensorflow as tf

import deepreg.config.parser as config_parser
import deepreg.model.optimizer as opt
from deepreg.callback import build_checkpoint_callback
from deepreg.registry import REGISTRY
from deepreg.util import build_dataset, build_log_dir

def build_config(
    config_path: Union[str, List[str]],
    log_dir: str,
    exp_name: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> Tuple[Dict, str, str]:
    """
    Initialise the log directory, parse the configuration
    for training, and back it up.

    :param config_path: path(s) to the configuration file(s)
    :param log_dir: path of the log root directory
    :param exp_name: name of the experiment
    :param ckpt_path: path where the model checkpoint is stored
    :param max_epochs: if max_epochs > 0, use it to overwrite the configuration
    :return: - config: a dictionary containing the parsed configuration
             - log_dir: the path of the directory to save logs
             - ckpt_path: the checkpoint path with the user home expanded
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # load config
    config = config_parser.load_configs(config_path)

    # replace the ~ with the user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # overwrite epochs and save_period if necessary
    if max_epochs > 0:
        config["train"]["epochs"] = max_epochs
        config["train"]["save_period"] = min(max_epochs, config["train"]["save_period"])

    # backup config
    config_parser.save(config=config, out_dir=log_dir)

    return config, log_dir, ckpt_path

def train(
    gpu: str,
    config_path: Union[str, List[str]],
    gpu_allow_growth: bool,
    ckpt_path: str,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Function to train a model.

    :param gpu: which local GPU(s) to use for training.
    :param config_path: path(s) to the configuration file(s).
    :param gpu_allow_growth: if True, let TensorFlow grow GPU memory usage
        on demand instead of reserving all of it upfront.
    :param ckpt_path: where to store training checkpoints.
    :param exp_name: experiment name.
    :param log_dir: path of the log directory.
    :param max_epochs: if max_epochs > 0, will use it to overwrite the configuration.
    """
    # set env variables
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None
    data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="valid",
        training=False,
        repeat=True,
    )

    # use strategy to support multiple GPUs
    # the network is mirrored on each GPU so that a larger batch size can be used
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    batch_size = config["train"]["preprocess"]["batch_size"]
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} cannot be divided evenly "
                f"by the number of devices."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader_train.moving_image_shape,
                fixed_image_size=data_loader_train.fixed_image_shape,
                index_size=data_loader_train.num_indices,
                labeled=config["dataset"]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # build callbacks
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=config["train"]["save_period"],
        update_freq=config["train"].get("update_freq", "epoch"),
    )
    ckpt_callback, initial_epoch = build_checkpoint_callback(
        model=model,
        dataset=dataset_train,
        log_dir=log_dir,
        save_period=config["train"]["save_period"],
        ckpt_path=ckpt_path,
    )
    callbacks = [tensorboard_callback, ckpt_callback]

    # train
    # steps_per_epoch and validation_steps must be defined to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=dataset_train,
        steps_per_epoch=steps_per_epoch_train,
        initial_epoch=initial_epoch,
        epochs=config["train"]["epochs"],
        validation_data=dataset_val,
        validation_steps=steps_per_epoch_val,
        callbacks=callbacks,
    )

    # close file loaders in data loaders after training
    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()

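# Aside (illustrative sketch, not part of deepreg/train.py): under
# tf.distribute.MirroredStrategy, the configured batch size is the global
# batch size, which TensorFlow splits evenly across replicas; that is why
# train() checks divisibility before mirroring.
#
#     num_devices = 2   # e.g. len(tf.config.list_physical_devices("GPU"))
#     batch_size = 4    # e.g. config["train"]["preprocess"]["batch_size"]
#     assert batch_size % num_devices == 0
#     per_replica_batch = batch_size // num_devices  # 2 samples per GPU per step
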
def main(args=None):
    """
    Entry point for train script.

    :param args: command line arguments; defaults to sys.argv when None
    """

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for training. "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPUs 0 and 1.',
        type=str,
        required=True,
    )

    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory.",
        action="store_true",  # boolean flag: present means True, absent means False
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of the saved model checkpoint to load. "
        "No need to provide if starting training from scratch.",
        default="",
        type=str,
        required=False,
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name",
        "-l",
        help="Name of log directory. "
        "The directory is under the log root, e.g. logs/ by default. "
        "If not provided, a timestamp-based folder will be created.",
        default="",
        type=str,
    )

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="+",
        required=True,
    )

    parser.add_argument(
        "--max_epochs",
        help="The maximum number of epochs; -1 means following the configuration.",
        type=int,
        default=-1,
    )

    args = parser.parse_args(args)
    train(
        gpu=args.gpu,
        config_path=args.config_path,
        gpu_allow_growth=args.gpu_allow_growth,
        ckpt_path=args.ckpt_path,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        max_epochs=args.max_epochs,
    )


if __name__ == "__main__":
    main()  # pragma: no cover
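
For reference, a minimal usage sketch driving the entry point above programmatically; "demo.yaml" and "demo_run" are placeholder values, and any valid DeepReg config path works:

# Sketch: call the train CLI entry point with an explicit argument list.
from deepreg.train import main

main(
    [
        "--gpu", "",                   # empty string trains on CPU
        "--config_path", "demo.yaml",  # placeholder; must end with .yaml
        "--exp_name", "demo_run",      # placeholder experiment name
        "--max_epochs", "2",           # overrides epochs in the config
    ]
)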