# coding=utf-8

"""
Module to perform predictions on data using
command line interface.
"""

import argparse
import os
import shutil
from typing import Dict, List, Tuple, Union

import numpy as np
import tensorflow as tf

import deepreg.config.parser as config_parser
import deepreg.model.layer_util as layer_util
import deepreg.model.optimizer as opt
from deepreg import log
from deepreg.callback import build_checkpoint_callback
from deepreg.registry import REGISTRY
from deepreg.util import (
    build_dataset,
    build_log_dir,
    calculate_metrics,
    save_array,
    save_metric_dict,
)

# module-level logger shared by all functions in this file
logger = log.get(__name__)
def build_pair_output_path(indices: list, save_dir: str) -> Tuple[str, str]:
    """
    Create the output directories for one image pair.

    The pair directory is named ``pair_<i>_<j>_...`` from all indices except
    the last one. If the last index is non-negative, a ``label_<k>``
    sub-directory is created inside the pair directory as well.

    :param indices: indices of the pair, the last one is for label
    :param save_dir: directory of output
    :return: - pair_dir, str, directory for saving the moving/fixed image
             - label_dir, str, directory for saving the rest outputs,
               identical to pair_dir when the label index is negative
    """
    # cast indices to string and init directory name
    pair_name = "pair_" + "_".join(str(x) for x in indices[:-1])
    pair_dir = os.path.join(save_dir, pair_name)
    os.makedirs(pair_dir, exist_ok=True)

    label_index = indices[-1]
    if label_index < 0:
        # negative label index: no dedicated label directory is created
        return pair_dir, pair_dir

    label_dir = os.path.join(pair_dir, f"label_{label_index}")
    os.makedirs(label_dir, exist_ok=True)
    return pair_dir, label_dir
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Function to predict results from a dataset from some model.

    For every batch the model outputs are post-processed, each returned
    array is written to disk per sample, and metrics are calculated per
    sample. All metrics are saved under save_dir at the end.

    :param dataset: where data is stored
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: ddf / dvf / affine / conditional,
        currently not used inside this function
    :param save_dir: path to store dir, removed first if it already exists
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :raises ValueError: if the same sample indices are met twice
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    # a set gives constant-time detection of repeated samples
    seen_sample_index_strs = set()
    metric_lists = []
    for inputs in dataset:
        # batch size is read from the first input tensor
        batch_size = inputs[list(inputs.keys())[0]].shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays
        # each value is a tuple (array, normalize, on_label)
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # save label independent tensors under pair_dir, otherwise under label_dir

            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    # theta is written as a text file instead of an image
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label's value is already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in seen_sample_index_strs:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            seen_sample_index_strs.add(sample_index_str)

            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0] if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled
                else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
def build_config(
    config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str
) -> Tuple[Dict, str, str]:
    """
    Function to create new directory to log directory to store results.

    When config_path is an empty string, the configuration is loaded from
    ``config.yaml`` located two directory levels above the checkpoint,
    i.e. next to the ``save`` folder for a checkpoint path like
    ``log_folder/save/ckpt-x``.

    :param config_path: path of configuration files.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where model is stored.
    :return: - config, configuration dictionary.
             - log_dir, path of the directory for saving outputs.
             - ckpt_path, checkpoint path with ``~`` expanded.
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder
        # os.path keeps short relative paths relative (instead of pointing
        # at the filesystem root) and honours the platform separator
        default_config_dir = os.path.dirname(os.path.dirname(ckpt_path))
        config = config_parser.load_configs(
            os.path.join(default_config_dir, "config.yaml")
        )
    else:
        # use customized config
        logger.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path
def predict(
    gpu: str,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to predict some metrics from the saved model and logging results.

    :param gpu: which env gpu to use.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param mode: train / valid / test, to define which split of dataset to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: if true, TF_FORCE_GPU_ALLOW_GROWTH is set to "true",
        otherwise to "false".
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    :raises ValueError: if multiple GPUs are used and batch_size
        is not divisible by the number of devices.
    """

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    if num_workers > 0:
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    # the requested batch size overrides the one stored in the config
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    try:
        # use strategy to support multiple GPUs
        # the network is mirrored in each GPU so that we can use larger batch size
        # https://www.tensorflow.org/guide/distributed_training
        # only model, optimizer and metrics need to be defined inside the strategy
        num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
        if num_devices > 1:  # pragma: no cover
            strategy = tf.distribute.MirroredStrategy()
            if batch_size % num_devices != 0:
                raise ValueError(
                    f"batch size {batch_size} can not be divided evenly "
                    "by the number of devices."
                )
        else:
            strategy = tf.distribute.get_strategy()
        with strategy.scope():
            model: tf.keras.Model = REGISTRY.build_model(
                config=dict(
                    name=config["train"]["method"],
                    moving_image_size=data_loader.moving_image_shape,
                    fixed_image_size=data_loader.fixed_image_shape,
                    index_size=data_loader.num_indices,
                    labeled=config["dataset"]["labeled"],
                    batch_size=batch_size,
                    config=config["train"],
                )
            )
            optimizer = opt.build_optimizer(
                optimizer_config=config["train"]["optimizer"]
            )
            model.compile(optimizer=optimizer)
            model.plot_model(output_dir=log_dir)

        # load weights
        if ckpt_path.endswith(".ckpt"):
            # for ckpt from tf.keras.callbacks.ModelCheckpoint
            # skip warnings because of optimizers
            # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
            model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
        else:
            # for ckpts from ckpt manager callback
            _, _ = build_checkpoint_callback(
                model=model,
                dataset=dataset,
                log_dir=log_dir,
                save_period=config["train"]["save_period"],
                ckpt_path=ckpt_path,
            )

        # predict
        fixed_grid_ref = tf.expand_dims(
            layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
            axis=0,
        )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        predict_on_dataset(
            dataset=dataset,
            fixed_grid_ref=fixed_grid_ref,
            model=model,
            model_method=config["train"]["method"],
            save_dir=os.path.join(log_dir, "test"),
            save_nifti=save_nifti,
            save_png=save_png,
        )
    finally:
        # close the opened files in data loaders, even if prediction failed
        data_loader.close()
def _str_to_bool(value: str) -> bool:
    """Convert a command line string such as true/false/1/0 into a bool."""
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def main(args=None):
    """
    Entry point for predict script.

    Parses the command line arguments and forwards them to ``predict``.

    :param args: list of command line arguments to parse,
        when None the arguments are read from ``sys.argv`` by argparse.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for training. "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=True,
    )

    # accepts both the bare flag (-gr) and an explicit value (-gr true / -gr false)
    # so that the parsed value is always a real boolean
    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory",
        nargs="?",
        const=True,
        default=False,
        type=_str_to_bool,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of checkpointed model to load",
        default="",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--mode",
        "-m",
        help="Define the split of data to be used for prediction. "
        "train or valid or test",
        type=str,
        default="test",
        required=True,
    )

    parser.add_argument(
        "--batch_size", "-b", help="Batch size for predictions", default=1, type=int
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name", "-n", help="Name of the experiment.", default="", type=str
    )

    # nifti output is enabled unless --no_nifti is given
    parser.add_argument("--save_nifti", dest="nifti", action="store_true")
    parser.add_argument("--no_nifti", dest="nifti", action="store_false")
    parser.set_defaults(nifti=True)

    # png output is disabled unless --save_png is given
    parser.add_argument("--save_png", dest="png", action="store_true")
    parser.add_argument("--no_png", dest="png", action="store_false")
    parser.set_defaults(png=False)

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="*",
        default="",
    )

    args = parser.parse_args(args)

    predict(
        gpu=args.gpu,
        ckpt_path=args.ckpt_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        mode=args.mode,
        batch_size=args.batch_size,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        config_path=args.config_path,
        save_nifti=args.nifti,
        save_png=args.png,
    )
# run the command line entry point when executed as a script
if __name__ == "__main__":
    main()  # pragma: no cover