1
|
|
|
# coding=utf-8 |
2
|
|
|
|
3
|
|
|
""" |
4
|
|
|
Module to train a network using init files and a CLI. |
5
|
|
|
""" |
6
|
|
|
|
7
|
|
|
import argparse |
8
|
|
|
import os |
9
|
|
|
from typing import Dict, List, Tuple, Union |
10
|
|
|
|
11
|
|
|
import tensorflow as tf |
12
|
|
|
|
13
|
|
|
import deepreg.config.parser as config_parser |
14
|
|
|
import deepreg.model.optimizer as opt |
15
|
|
|
from deepreg.callback import build_checkpoint_callback |
16
|
|
|
from deepreg.registry import REGISTRY |
17
|
|
|
from deepreg.util import build_dataset, build_log_dir |
18
|
|
|
|
19
|
|
|
|
20
|
|
|
def build_config(
    config_path: Union[str, List[str]],
    log_dir: str,
    exp_name: str,
    ckpt_path: str,
    max_epochs: int = -1,
) -> Tuple[Dict, str, str]:
    """
    Prepare the configuration and paths required before training.

    The steps are, in order: initialise the log directory through
    build_log_dir, load the configuration, expand ``~`` in the checkpoint
    path, optionally overwrite the epoch settings, and back up the final
    configuration into the log directory.

    :param config_path: str or list of str, path(s) to the config file(s)
    :param log_dir: path of the log directory
    :param exp_name: name of the experiment
    :param ckpt_path: path where model is stored, ``~`` is expanded
    :param max_epochs: if max_epochs > 0, use it to overwrite the configured
        epochs and to cap the configured save_period
    :return: - config: a dictionary saving configuration
             - log_dir: the value returned by build_log_dir
             - ckpt_path: the checkpoint path with the user home expanded
    """
    # the log directory is initialised first, as the backup below needs it
    run_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    config = config_parser.load_configs(config_path)

    # replace the ~ with user home path
    resolved_ckpt_path = os.path.expanduser(ckpt_path)

    if max_epochs > 0:
        # a positive value takes priority over the configuration,
        # and the save period can not be larger than the number of epochs
        train_config = config["train"]
        train_config["epochs"] = max_epochs
        train_config["save_period"] = min(max_epochs, train_config["save_period"])

    # keep a copy of the configuration actually used next to the logs
    config_parser.save(config=config, out_dir=run_dir)

    return config, run_dir, resolved_ckpt_path
59
|
|
|
|
60
|
|
|
|
61
|
|
|
def train(
    gpu: str,
    config_path: Union[str, List[str]],
    gpu_allow_growth: bool,
    ckpt_path: str,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
) -> None:
    """
    Function to train a model.

    The data loaders opened for training and validation are always closed
    before the function returns, including when building the model or
    fitting raises an exception.

    :param gpu: which local gpu to use to train,
        the value is written to CUDA_VISIBLE_DEVICES.
    :param config_path: path to configuration set up.
    :param gpu_allow_growth: if True, TF_FORCE_GPU_ALLOW_GROWTH is set to
        "true" so that the whole GPU memory is not reserved up front.
    :param ckpt_path: path of a saved checkpoint, forwarded to
        build_checkpoint_callback which also provides the initial epoch.
    :param exp_name: experiment name.
    :param log_dir: path of the log directory.
    :param max_epochs: if max_epochs > 0, will use it to overwrite the configuration.
    :raises ValueError: if multiple GPUs are used and the batch size
        can not be divided evenly by the number of devices.
    """
    # set env variables
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None

    # defined before the try block so that the cleanup in finally is valid
    # even if building the validation dataset fails
    data_loader_val = None
    try:
        data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
            dataset_config=config["dataset"],
            preprocess_config=config["train"]["preprocess"],
            mode="valid",
            training=False,
            repeat=True,
        )

        # use strategy to support multiple GPUs
        # the network is mirrored in each GPU so that we can use larger batch size
        # https://www.tensorflow.org/guide/distributed_training
        # only model, optimizer and metrics need to be defined inside the strategy
        num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
        batch_size = config["train"]["preprocess"]["batch_size"]
        if num_devices > 1:  # pragma: no cover
            strategy = tf.distribute.MirroredStrategy()
            if batch_size % num_devices != 0:
                raise ValueError(
                    f"batch size {batch_size} can not be divided evenly "
                    "by the number of devices."
                )
        else:
            strategy = tf.distribute.get_strategy()
        with strategy.scope():
            model: tf.keras.Model = REGISTRY.build_model(
                config=dict(
                    name=config["train"]["method"],
                    moving_image_size=data_loader_train.moving_image_shape,
                    fixed_image_size=data_loader_train.fixed_image_shape,
                    index_size=data_loader_train.num_indices,
                    labeled=config["dataset"]["labeled"],
                    batch_size=batch_size,
                    config=config["train"],
                )
            )
            optimizer = opt.build_optimizer(
                optimizer_config=config["train"]["optimizer"]
            )
            model.compile(optimizer=optimizer)
            model.plot_model(output_dir=log_dir)

        # build callbacks
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=log_dir,
            histogram_freq=config["train"]["save_period"],
            update_freq=config["train"].get("update_freq", "epoch"),
        )
        ckpt_callback, initial_epoch = build_checkpoint_callback(
            model=model,
            dataset=dataset_train,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )
        callbacks = [tensorboard_callback, ckpt_callback]

        # train
        # it's necessary to define the steps_per_epoch
        # and validation_steps to prevent errors like
        # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
        model.fit(
            x=dataset_train,
            steps_per_epoch=steps_per_epoch_train,
            initial_epoch=initial_epoch,
            epochs=config["train"]["epochs"],
            validation_data=dataset_val,
            validation_steps=steps_per_epoch_val,
            callbacks=callbacks,
        )
    finally:
        # close file loaders in data loaders,
        # no matter whether the training succeeded or raised
        data_loader_train.close()
        if data_loader_val is not None:
            data_loader_val.close()
175
|
|
|
|
176
|
|
|
|
177
|
|
|
def _parse_bool_flag(value: str) -> bool:
    """
    Convert the textual value of a boolean CLI option into a bool.

    Only explicit negative spellings are mapped to False, any other
    value keeps enabling the option as it did before.

    :param value: raw string received from the command line
    :return: False for "", "0", "false", "f", "no", "n", "off"
        (case-insensitive, surrounding spaces ignored), True otherwise
    """
    return value.strip().lower() not in {"", "0", "false", "f", "no", "n", "off"}


def main(args=None):
    """
    Entry point for train script.

    :param args: list of command line arguments,
        it is passed to ArgumentParser.parse_args as it is,
        which reads sys.argv when it is None.
    """

    parser = argparse.ArgumentParser()

    # adjacent string literals are concatenated without any separator,
    # so each part of a multi-part help message ends with a space
    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for training. "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=True,
    )

    # without a converter, a value such as "false" stays a non-empty string,
    # which is truthy and would enable the memory growth
    # nargs="?" with const=True further allows the flag to be given alone
    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory. "
        "Can be given alone or followed by an explicit true/false value.",
        nargs="?",
        const=True,
        default=False,
        type=_parse_bool_flag,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of the saved model checkpoint to load. "
        "No need to provide if start training from scratch.",
        default="",
        type=str,
        required=False,
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name",
        "-l",
        help="Name of log directory. "
        "The directory is under log root, e.g. logs/ by default. "
        "If not provided, a timestamp based folder will be created.",
        default="",
        type=str,
    )

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="+",
        required=True,
    )

    parser.add_argument(
        "--max_epochs",
        help="The maximum number of epochs, -1 means following configuration.",
        type=int,
        default=-1,
    )

    args = parser.parse_args(args)
    train(
        gpu=args.gpu,
        config_path=args.config_path,
        gpu_allow_growth=args.gpu_allow_growth,
        ckpt_path=args.ckpt_path,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        max_epochs=args.max_epochs,
    )
254
|
|
|
|
255
|
|
|
|
256
|
|
|
if __name__ == "__main__":
    # start the command line interface only when this file is executed
    # as a script, importing the module does not trigger any training
    main()  # pragma: no cover
258
|
|
|
|