1
|
|
|
# coding=utf-8 |
2
|
|
|
|
3
|
|
|
""" |
4
|
|
|
Module to train a network from YAML configuration files via a command-line interface.
5
|
|
|
""" |
6
|
|
|
|
7
|
|
|
import argparse |
8
|
|
|
import os |
9
|
|
|
from typing import Dict, List, Tuple, Union |
10
|
|
|
|
11
|
|
|
import tensorflow as tf |
12
|
|
|
|
13
|
|
|
import deepreg.config.parser as config_parser |
14
|
|
|
import deepreg.model.optimizer as opt |
15
|
|
|
from deepreg import log |
16
|
|
|
from deepreg.callback import build_checkpoint_callback |
17
|
|
|
from deepreg.registry import REGISTRY |
18
|
|
|
from deepreg.util import build_dataset, build_log_dir |
19
|
|
|
|
20
|
|
|
# Module-level logger obtained from deepreg's ``log`` helper and keyed by this
# module's name; used below to report training-setup messages.
logger = log.get(__name__)
21
|
|
|
|
22
|
|
|
|
23
|
|
|
def build_config(
    config_path: Union[str, List[str]],
    log_dir: str,
    exp_name: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> Tuple[Dict, str, str]:
    """
    Prepare the log directory, configuration and checkpoint path for training.

    The log directory is initialised via ``build_log_dir``, the configuration
    is loaded from ``config_path``, ``~`` in ``ckpt_path`` is expanded to the
    user home directory, the epoch settings are optionally overridden, and a
    backup of the final configuration is saved into the log directory.

    :param config_path: path, or list of paths, to the configuration file(s).
    :param log_dir: root path of the log directory.
    :param exp_name: name of the experiment, used to build the log directory.
    :param ckpt_path: path of a saved model checkpoint; ``~`` is expanded.
    :param max_epochs: if > 0, overrides ``config["train"]["epochs"]`` and caps
        ``config["train"]["save_period"]`` at this value.
    :return: a tuple of
        - config: the parsed configuration dictionary,
        - log_dir: the path of the directory where logs are saved,
        - ckpt_path: the checkpoint path with ``~`` expanded.
    """
    experiment_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)
    parsed = config_parser.load_configs(config_path)
    resolved_ckpt = os.path.expanduser(ckpt_path)

    if max_epochs > 0:
        train_section = parsed["train"]
        train_section["epochs"] = max_epochs
        # keep the save period from exceeding the (possibly shortened) run length
        train_section["save_period"] = min(max_epochs, train_section["save_period"])

    # keep a copy of the effective configuration next to the logs
    config_parser.save(config=parsed, out_dir=experiment_dir)

    return parsed, experiment_dir, resolved_ckpt
62
|
|
|
|
63
|
|
|
|
64
|
|
|
def _configure_environment(gpu: str, gpu_allow_growth: bool, num_workers: int) -> None:
    """
    Set the environment variables controlling GPU visibility, GPU memory growth
    and the number of CPU threads.

    :param gpu: value for CUDA_VISIBLE_DEVICES; left untouched when None.
    :param gpu_allow_growth: if True, TensorFlow is told not to reserve all GPU
        memory up front (TF_FORCE_GPU_ALLOW_GROWTH).
    :param num_workers: number of CPU threads to allow; <= 0 means not limited.
    """
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    # only a positive value imposes a limit; 0 or negative leaves CPU usage unlimited
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the training. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)


def _build_strategy(batch_size: int) -> "tf.distribute.Strategy":
    """
    Return the distribution strategy under which the model is built.

    With more than one visible GPU a MirroredStrategy is used, so the network is
    mirrored on each GPU and the batch is split across devices.
    See https://www.tensorflow.org/guide/distributed_training

    :param batch_size: global training batch size.
    :return: a MirroredStrategy for multiple GPUs, otherwise the default strategy.
    :raises ValueError: if batch_size cannot be divided evenly by the number of GPUs.
    """
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                "by the number of devices."
            )
        return tf.distribute.MirroredStrategy()
    return tf.distribute.get_strategy()


def train(
    gpu: str,
    config_path: Union[str, List[str]],
    ckpt_path: str,
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Function to train a model.

    :param gpu: which local gpu to use to train; written to CUDA_VISIBLE_DEVICES
        unless None.
    :param config_path: path, or list of paths, to the configuration file(s).
    :param ckpt_path: path of a saved model checkpoint to load;
        empty string to train from scratch.
    :param num_workers: number of CPU threads to be used, <= 0 means not limited.
    :param gpu_allow_growth: if True, TensorFlow allocates GPU memory on demand
        instead of reserving all of it for training.
    :param exp_name: experiment name.
    :param log_dir: path of the log directory.
    :param max_epochs: if max_epochs > 0, will use it to overwrite the configuration.
    :raises ValueError: if, with multiple GPUs, the batch size cannot be divided
        evenly by the number of devices.
    """
    # set env variables
    _configure_environment(
        gpu=gpu, gpu_allow_growth=gpu_allow_growth, num_workers=num_workers
    )

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None

    # everything after this point may raise; the finally clause guarantees the
    # file loaders are closed even when building or training fails
    data_loader_val = None
    try:
        data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
            dataset_config=config["dataset"],
            preprocess_config=config["train"]["preprocess"],
            split="valid",
            training=False,
            repeat=True,
        )

        # only model, optimizer and metrics need to be defined inside the strategy
        batch_size = config["train"]["preprocess"]["batch_size"]
        strategy = _build_strategy(batch_size=batch_size)
        with strategy.scope():
            model: tf.keras.Model = REGISTRY.build_model(
                config=dict(
                    name=config["train"]["method"],
                    moving_image_size=data_loader_train.moving_image_shape,
                    fixed_image_size=data_loader_train.fixed_image_shape,
                    index_size=data_loader_train.num_indices,
                    labeled=config["dataset"]["train"]["labeled"],
                    batch_size=batch_size,
                    config=config["train"],
                )
            )
            optimizer = opt.build_optimizer(
                optimizer_config=config["train"]["optimizer"]
            )
            model.compile(optimizer=optimizer)
            model.plot_model(output_dir=log_dir)

        # build callbacks
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=log_dir,
            histogram_freq=config["train"]["save_period"],
            update_freq=config["train"].get("update_freq", "epoch"),
        )
        ckpt_callback, initial_epoch = build_checkpoint_callback(
            model=model,
            dataset=dataset_train,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )
        callbacks = [tensorboard_callback, ckpt_callback]

        # train
        # it's necessary to define the steps_per_epoch
        # and validation_steps to prevent errors like
        # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
        model.fit(
            x=dataset_train,
            steps_per_epoch=steps_per_epoch_train,
            initial_epoch=initial_epoch,
            epochs=config["train"]["epochs"],
            validation_data=dataset_val,
            validation_steps=steps_per_epoch_val,
            callbacks=callbacks,
        )
    finally:
        # close file loaders in data loaders after training (or on failure)
        data_loader_train.close()
        if data_loader_val is not None:
            data_loader_val.close()
197
|
|
|
|
198
|
|
|
|
199
|
|
|
def _str2bool(value: Union[str, bool]) -> bool:
    """
    Convert a command-line value into a boolean.

    argparse hands over raw strings, and any non-empty string (including
    "False") is truthy, so boolean options need an explicit conversion.

    :param value: raw value from argparse, or an already-boolean value.
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def main(args=None):
    """
    Entry point for train script.

    Parses the command-line arguments and forwards them to ``train``.

    :param args: list of command-line arguments to parse; when None, argparse
        reads them from ``sys.argv``.
    """

    parser = argparse.ArgumentParser()

    # adjacent string literals are concatenated, hence the trailing spaces
    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for training. "
        "-g for using GPU remotely. "
        '-g "" for using CPU. '
        '-g "0" for using GPU 0. '
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=False,
    )

    # parsed into a real bool: previously `-gr False` produced the truthy string
    # "False"; `-gr` on its own now means True, `-gr true/false` still works
    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory. "
        "Accepts true/false (also yes/no, 1/0); the flag alone means true.",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of the saved model checkpoint to load. "
        "No need to provide if start training from scratch.",
        default="",
        type=str,
        required=False,
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name",
        "-l",
        help="Name of log directory. "
        "The directory is under log root, e.g. logs/ by default. "
        "If not provided, a timestamp based folder will be created.",
        default="",
        type=str,
    )

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="+",
        required=True,
    )

    parser.add_argument(
        "--max_epochs",
        help="The maximum number of epochs, -1 means following configuration.",
        type=int,
        default=-1,
    )

    args = parser.parse_args(args)
    train(
        gpu=args.gpu,
        config_path=args.config_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        ckpt_path=args.ckpt_path,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        max_epochs=args.max_epochs,
    )
285
|
|
|
|
286
|
|
|
|
287
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":  # pragma: no cover
    main()
289
|
|
|
|