# coding=utf-8
"""
Module to perform predictions on data using the
command line interface.
"""

import argparse
import os
import shutil
from typing import Dict, List, Tuple, Union

import numpy as np
import tensorflow as tf

import deepreg.config.parser as config_parser
import deepreg.model.layer_util as layer_util
import deepreg.model.optimizer as opt
from deepreg import log
from deepreg.callback import build_checkpoint_callback
from deepreg.registry import REGISTRY
from deepreg.util import (
    build_dataset,
    build_log_dir,
    calculate_metrics,
    save_array,
    save_metric_dict,
)

logger = log.get(__name__)


def build_pair_output_path(indices: list, save_dir: str) -> Tuple[str, str]:
    """
    Create the directories for saving the data of one pair.

    :param indices: indices of the pair, the last one is for label
    :param save_dir: directory of output
    :return: - pair_dir, str, directory for saving the moving/fixed image
             - label_dir, str, directory for saving the remaining outputs
    """
    # cast indices to string and init directory name
    pair_index = "pair_" + "_".join([str(x) for x in indices[:-1]])
    pair_dir = os.path.join(save_dir, pair_index)
    os.makedirs(pair_dir, exist_ok=True)

    if indices[-1] >= 0:
        label_index = f"label_{indices[-1]}"
        label_dir = os.path.join(pair_dir, label_index)
        os.makedirs(label_dir, exist_ok=True)
    else:
        label_dir = pair_dir

    return pair_dir, label_dir
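
# For example (hypothetical values), indices=[0, 2, 1] with save_dir="out"
# creates "out/pair_0_2/label_1" and returns both directories; a label index
# of -1 means there is no label, so label_dir falls back to pair_dir.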


def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Predict results on a dataset with a given model and save the outputs.

    :param dataset: dataset to be evaluated
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: ddf / dvf / affine / conditional
    :param save_dir: path of the directory to save outputs
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    sample_index_strs = []
    metric_lists = []
    for _, inputs in enumerate(dataset):
        batch_size = inputs[list(inputs.keys())[0]].shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)
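        # each entry of processed maps a tensor name to a tuple
        # (array, normalize, on_label): normalize tells save_array whether to
        # rescale values before saving, and on_label marks tensors that depend
        # on the label, which are saved under label_dir instead of pair_dir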

        # convert to np arrays
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # save label independent tensors under pair_dir, otherwise under label_dir

            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label's value is already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in sample_index_strs:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            sample_index_strs.append(sample_index_str)

            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0] if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled
                else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metrics
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)


def build_config(
    config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str
) -> Tuple[Dict, str, str]:
    """
    Prepare the configuration and the log directory for storing results.

    :param config_path: path(s) of configuration files.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where the model checkpoint is stored.
    :return: - config, configuration dictionary.
             - log_dir, path of the directory for saving outputs.
             - ckpt_path, path of the model checkpoint with ~ expanded.
    """
    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder
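        # e.g. a ckpt_path such as "logs/exp/save/ckpt-100" (hypothetical)
        # resolves to the default config path "logs/exp/config.yaml"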
        config = config_parser.load_configs(
            "/".join(ckpt_path.split("/")[:-2]) + "/config.yaml"
        )
    else:
        # use customized config
        logger.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path


def predict(
    gpu: str,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Load a trained model, predict on the chosen data split, and save outputs and metrics.

    :param gpu: which env gpu to use.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param mode: train / valid / test, the split of the dataset to be evaluated.
    :param batch_size: batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: if True, allocate GPU memory on demand instead of
        reserving it all.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    """
    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices {num_devices}."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader.moving_image_shape,
                fixed_image_size=data_loader.fixed_image_shape,
                index_size=data_loader.num_indices,
                labeled=config["dataset"]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape), axis=0
    )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
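    # the reference grid holds the voxel coordinates of the fixed image and is
    # used by calculate_metrics when evaluating the predictions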
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        model_method=config["train"]["method"],
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()
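
# Example call (a sketch; the checkpoint path is hypothetical):
#   predict(gpu="0", ckpt_path="logs/exp/save/ckpt-100", mode="test",
#           batch_size=1, exp_name="exp_pred", config_path="")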


def main(args=None):
    """
    Entry point for predict script.

    :param args: arguments to parse, uses sys.argv when None.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for prediction. "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=True,
    )

    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory",
        default=False,
        action="store_true",
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of the checkpointed model to load",
        default="",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--mode",
        "-m",
        help="Split of the data to be used for prediction: train, valid or test.",
        type=str,
        default="test",
        required=True,
    )

    parser.add_argument(
        "--batch_size", "-b", help="Batch size for predictions", default=1, type=int
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name", "-n", help="Name of the experiment.", default="", type=str
    )

    parser.add_argument("--save_nifti", dest="nifti", action="store_true")
    parser.add_argument("--no_nifti", dest="nifti", action="store_false")
    parser.set_defaults(nifti=True)

    parser.add_argument("--save_png", dest="png", action="store_true")
    parser.add_argument("--no_png", dest="png", action="store_false")
    parser.set_defaults(png=False)

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="*",
        default="",
    )

    args = parser.parse_args(args)

    predict(
        gpu=args.gpu,
        ckpt_path=args.ckpt_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        mode=args.mode,
        batch_size=args.batch_size,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        config_path=args.config_path,
        save_nifti=args.nifti,
        save_png=args.png,
    )


if __name__ == "__main__":
    main()  # pragma: no cover