# coding=utf-8

"""
Module to perform predictions on data using
the command line interface.
"""

import argparse
import logging
import os
import shutil
from typing import Dict, List, Tuple, Union

import numpy as np
import tensorflow as tf

import deepreg.config.parser as config_parser
import deepreg.model.layer_util as layer_util
import deepreg.model.optimizer as opt
from deepreg.callback import build_checkpoint_callback
from deepreg.registry import REGISTRY
from deepreg.util import (
    build_dataset,
    build_log_dir,
    calculate_metrics,
    save_array,
    save_metric_dict,
)


def build_pair_output_path(indices: list, save_dir: str) -> Tuple[str, str]:
    """
    Create directories for saving the paired data.

    :param indices: indices of the pair; the last one corresponds to the label
    :param save_dir: directory of output
    :return: - pair_dir, str, directory for saving the moving/fixed images
             - label_dir, str, directory for saving the remaining outputs
    """

    # cast indices to string and init directory name
    pair_index = "pair_" + "_".join([str(x) for x in indices[:-1]])
    pair_dir = os.path.join(save_dir, pair_index)
    os.makedirs(pair_dir, exist_ok=True)

    if indices[-1] >= 0:
        label_index = f"label_{indices[-1]}"
        label_dir = os.path.join(pair_dir, label_index)
        os.makedirs(label_dir, exist_ok=True)
    else:
        label_dir = pair_dir

    return pair_dir, label_dir


def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Function to predict results on a dataset using a given model.

    :param dataset: dataset where the data is stored
    :param fixed_grid_ref: reference grid, shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: ddf / dvf / affine / conditional
    :param save_dir: directory in which outputs are saved
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    sample_index_strs = []
    metric_lists = []
    for inputs in dataset:
        batch_size = next(iter(inputs.values())).shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }
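        # each value of processed is a tuple (array, normalize flag, on_label
        # flag), unpacked below when saving the per-sample outputs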

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # save label-independent tensors under pair_dir,
            # and the rest under label_dir

            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label values are already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in sample_index_strs:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            sample_index_strs.append(sample_index_str)

            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0] if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled
                else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
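    # on disk, outputs are grouped as <save_dir>/pair_<indices>/ and, for
    # labeled data, <save_dir>/pair_<indices>/label_<index>/
    # (see build_pair_output_path above)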


def build_config(
    config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str
) -> Tuple[Dict, str, str]:
    """
    Function to initialize the log directory and load the configuration.

    :param config_path: path of configuration files.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where the model is stored.
    :return: - config, configuration dictionary.
             - log_dir, path of the directory for saving outputs.
             - ckpt_path, path where the model is stored, with ~ expanded.
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with the user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if config_path == "":
        # use the default config, which should be provided in the log folder
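        # e.g. a hypothetical ckpt_path "logs/demo/save/ckpt-5" resolves to
        # the config at "logs/demo/config.yaml"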
        config = config_parser.load_configs(
            "/".join(ckpt_path.split("/")[:-2]) + "/config.yaml"
        )
    else:
        # use customized config
        logging.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path


def predict(
    gpu: str,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to run prediction with a saved model and log the results.

    :param gpu: which env gpu to use.
    :param ckpt_path: where the model is stored, should be like log_folder/save/ckpt-x.
    :param mode: train / valid / test, defines which split of the dataset is evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpus to be used, <= 0 means unlimited.
    :param gpu_allow_growth: if True, allocate GPU memory gradually instead of reserving it all.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    """

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    if num_workers > 0:
        # Maximum number of threads to use for OpenMP parallel regions.
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        # The two TF settings below are also required for the thread limit to
        # take effect. Thanks to @cjw85.
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # use a strategy to support multiple GPUs
    # the network is mirrored on each GPU so that a larger batch size can be used
    # https://www.tensorflow.org/guide/distributed_training
    # only the model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} cannot be divided evenly "
                f"by the number of devices {num_devices}."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader.moving_image_shape,
                fixed_image_size=data_loader.fixed_image_shape,
                index_size=data_loader.num_indices,
                labeled=config["dataset"]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for a ckpt saved by tf.keras.callbacks.ModelCheckpoint;
        # expect_partial() skips warnings about unrestored optimizer variables
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts saved by the ckpt manager callback
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )
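        # the callback restores the weights from ckpt_path as it is built; its
        # return values (callback, initial epoch) are not needed for prediction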

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape), axis=0
    )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        model_method=config["train"]["method"],
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in the data loaders
    data_loader.close()


def main(args=None):
    """
    Entry point for the predict script.

    :param args: arguments to parse; defaults to sys.argv[1:] when None.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for prediction. "
        '-g "" for using the CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPUs 0 and 1.',
        type=str,
        required=True,
    )

    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory.",
        action="store_true",
        default=False,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of the checkpointed model to load.",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--mode",
        "-m",
        help="Split of the data to be used for prediction: train, valid, or test.",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--batch_size", "-b", help="Batch size for predictions.", default=1, type=int
    )

    parser.add_argument(
        "--log_dir", help="Path of the log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name", "-n", help="Name of the experiment.", default="", type=str
    )

    parser.add_argument("--save_nifti", dest="nifti", action="store_true")
    parser.add_argument("--no_nifti", dest="nifti", action="store_false")
    parser.set_defaults(nifti=True)

    parser.add_argument("--save_png", dest="png", action="store_true")
    parser.add_argument("--no_png", dest="png", action="store_false")
    parser.set_defaults(png=False)

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="*",
        default="",
    )

    args = parser.parse_args(args)

    predict(
        gpu=args.gpu,
        ckpt_path=args.ckpt_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        mode=args.mode,
        batch_size=args.batch_size,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        config_path=args.config_path,
        save_nifti=args.nifti,
        save_png=args.png,
    )


if __name__ == "__main__":
    main()  # pragma: no cover