Issues (32)

deepreg/train.py (1 issue)

# coding=utf-8

"""
Module to train a network using init files and a CLI.
"""

import argparse
import os
from typing import Dict, List, Tuple, Union

import tensorflow as tf

import deepreg.config.parser as config_parser
import deepreg.model.optimizer as opt
from deepreg import log
from deepreg.callback import build_checkpoint_callback
from deepreg.registry import REGISTRY
from deepreg.util import build_dataset, build_log_dir

logger = log.get(__name__)


def build_config(
    config_path: Union[str, List[str]],
    log_dir: str,
    exp_name: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> Tuple[Dict, str, str]:
    """
    Initialise the log directory, check that the checkpointed model is of
    the right type, and parse the configuration for training.

    :param config_path: str or list of str, path(s) to config file(s)
    :param log_dir: path of the log root directory
    :param exp_name: name of the experiment
    :param ckpt_path: path where the model checkpoint is stored
    :param max_epochs: if max_epochs > 0, use it to overwrite the configuration
    :return: - config: a dictionary holding the parsed configuration
             - log_dir: the path of the directory to save logs
             - ckpt_path: the checkpoint path with the user home expanded
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # load config
    config = config_parser.load_configs(config_path)

    # replace ~ with the user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # overwrite epochs and save_period if necessary
    if max_epochs > 0:
        config["train"]["epochs"] = max_epochs
        config["train"]["save_period"] = min(max_epochs, config["train"]["save_period"])

    # backup config
    config_parser.save(config=config, out_dir=log_dir)

    return config, log_dir, ckpt_path

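For orientation, a minimal sketch of calling build_config directly, assuming a hypothetical config file and experiment name (neither ships with the project):

config, log_dir, ckpt_path = build_config(
    config_path="configs/train.yaml",  # hypothetical config file
    log_dir="logs",
    exp_name="my_experiment",          # hypothetical experiment name
    ckpt_path="",                      # empty string: train from scratch
    max_epochs=2,                      # overwrites epochs and save_period
)
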
def train(
    gpu: str,
    config_path: Union[str, List[str]],
    ckpt_path: str,
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Train a model.

    :param gpu: which local GPU(s) to use for training.
    :param config_path: path(s) to the configuration file(s).
    :param ckpt_path: where to store training checkpoints.
    :param num_workers: number of CPU cores to be used, <= 0 means not limited.
    :param gpu_allow_growth: whether to allocate GPU memory on demand
        instead of reserving all available memory up front.
    :param exp_name: experiment name.
    :param log_dir: path of the log directory.
    :param max_epochs: if max_epochs > 0, use it to overwrite the configuration.
    """
    # set env variables
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
            "true" if gpu_allow_growth else "false"
        )
    # ISSUE (flagged by the duplication checker): this CPU-limiting block is
    # duplicated elsewhere in the project; a de-duplication sketch follows
    # this function.
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the training. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None
    data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="valid",
        training=False,
        repeat=True,
    )

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use a larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    batch_size = config["train"]["preprocess"]["batch_size"]
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} cannot be divided evenly "
                "by the number of devices."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader_train.moving_image_shape,
                fixed_image_size=data_loader_train.fixed_image_shape,
                index_size=data_loader_train.num_indices,
                labeled=config["dataset"]["train"]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # build callbacks
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=config["train"]["save_period"],
        update_freq=config["train"].get("update_freq", "epoch"),
    )
    ckpt_callback, initial_epoch = build_checkpoint_callback(
        model=model,
        dataset=dataset_train,
        log_dir=log_dir,
        save_period=config["train"]["save_period"],
        ckpt_path=ckpt_path,
    )
    callbacks = [tensorboard_callback, ckpt_callback]

    # train
    # it is necessary to define steps_per_epoch and validation_steps
    # to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=dataset_train,
        steps_per_epoch=steps_per_epoch_train,
        initial_epoch=initial_epoch,
        epochs=config["train"]["epochs"],
        validation_data=dataset_val,
        validation_steps=steps_per_epoch_val,
        callbacks=callbacks,
    )

    # close file loaders in data loaders after training
    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()

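On the duplication flagged inside train above: a minimal sketch, assuming a shared helper is acceptable, of factoring the environment setup into one function that every call site imports. The name set_env_variables and its location are hypothetical, not part of DeepReg; the logging call is omitted for brevity.

import os  # already imported at module top; repeated so the sketch is self-contained


def set_env_variables(gpu=None, gpu_allow_growth=True, num_workers=1):
    """Hypothetical shared helper for the duplicated environment setup."""
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
            "true" if gpu_allow_growth else "false"
        )
    if num_workers > 0:
        # limit TensorFlow intra-/inter-op and OpenMP threads to num_workers
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

With such a helper, the body of train (and the duplicated call site) would reduce to a single set_env_variables(gpu=gpu, gpu_allow_growth=gpu_allow_growth, num_workers=num_workers) call before loading the config.
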
def main(args=None):
    """
    Entry point for the train script.

    :param args: arguments
    """

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for training. "
        "-g for using GPU remotely, "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=False,
    )

    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory.",
        action="store_true",
        default=False,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of the saved model checkpoint to load. "
        "No need to provide if training from scratch.",
        default="",
        type=str,
        required=False,
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name",
        "-l",
        help="Name of the log directory. "
        "The directory is under the log root, e.g. logs/ by default. "
        "If not provided, a timestamp-based folder will be created.",
        default="",
        type=str,
    )

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of the config file, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="+",
        required=True,
    )

    parser.add_argument(
        "--max_epochs",
        help="The maximum number of epochs; -1 means following the configuration.",
        type=int,
        default=-1,
    )

    args = parser.parse_args(args)
    train(
        gpu=args.gpu,
        config_path=args.config_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        ckpt_path=args.ckpt_path,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        max_epochs=args.max_epochs,
    )


if __name__ == "__main__":
    main()  # pragma: no cover
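
For completeness, a minimal usage sketch: main accepts an argument list, so the CLI parsing can be exercised programmatically. The config file below is a hypothetical placeholder; the equivalent shell command is python -m deepreg.train with the same flags, which follows from the __main__ guard above.

main(
    args=[
        "--gpu", "",                            # empty string: train on CPU
        "--config_path", "configs/train.yaml",  # hypothetical config file
        "--exp_name", "my_experiment",          # hypothetical experiment name
    ]
)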