DeepRegNet /
DeepReg
| 1 | # coding=utf-8 |
||
| 2 | |||
| 3 | """ |
||
| 4 | Module to perform predictions on data using |
||
| 5 | command line interface. |
||
| 6 | """ |
||
| 7 | |||
| 8 | import argparse |
||
| 9 | import os |
||
| 10 | import shutil |
||
| 11 | from typing import Dict, List, Tuple, Union |
||
| 12 | |||
| 13 | import numpy as np |
||
| 14 | import tensorflow as tf |
||
| 15 | |||
| 16 | import deepreg.config.parser as config_parser |
||
| 17 | import deepreg.model.layer_util as layer_util |
||
| 18 | import deepreg.model.optimizer as opt |
||
| 19 | from deepreg import log |
||
| 20 | from deepreg.callback import build_checkpoint_callback |
||
| 21 | from deepreg.registry import REGISTRY |
||
| 22 | from deepreg.util import ( |
||
| 23 | build_dataset, |
||
| 24 | build_log_dir, |
||
| 25 | calculate_metrics, |
||
| 26 | save_array, |
||
| 27 | save_metric_dict, |
||
| 28 | ) |
||
| 29 | |||
# module-level logger obtained from deepreg's logging helper
logger = log.get(__name__)
||
| 31 | |||
| 32 | |||
def build_pair_output_path(indices: list, save_dir: str) -> Tuple[str, str]:
    """
    Create the output directories for one sample pair and, if any, its label.

    The pair directory is named ``pair_<i0>_<i1>_...`` using every index but the
    last one. The last index identifies the label: when it is non-negative a
    ``label_<k>`` sub-directory is created inside the pair directory, otherwise
    the pair directory is reused for label-dependent outputs.

    :param indices: indices of the pair, the last one is for label
    :param save_dir: root directory of the outputs
    :return: - pair_dir, str, directory for label-independent outputs
               (e.g. the moving/fixed images)
             - label_dir, str, directory for label-dependent outputs;
               identical to pair_dir when the label index is negative
    """
    pair_indices = indices[:-1]
    pair_dir = os.path.join(save_dir, "pair_" + "_".join(map(str, pair_indices)))
    os.makedirs(pair_dir, exist_ok=True)

    label_index = indices[-1]
    if label_index < 0:
        # unlabeled sample: everything is stored directly in the pair directory
        return pair_dir, pair_dir

    label_dir = os.path.join(pair_dir, f"label_{label_index}")
    os.makedirs(label_dir, exist_ok=True)
    return pair_dir, label_dir
||
| 56 | |||
| 57 | |||
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Function to predict results from a dataset from some model.

    ``save_dir`` is removed first if it already exists. Each batch is passed
    through ``model.predict`` and ``model.postprocess``; every post-processed
    output is then saved per sample, under the label directory when it is
    flagged as label dependent and under the pair directory otherwise (see
    ``build_pair_output_path``). An output named ``theta`` is written as
    ``affine.txt`` in the pair directory. Per-sample metrics are collected and
    written with ``save_metric_dict`` once the whole dataset is processed.

    :param dataset: where data is stored; must not be repeated.
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param save_dir: path to store dir
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :raises ValueError: if the same sample index is seen twice, e.g. because
        the dataset has been repeated.
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    # a set keeps the duplicate check O(1) per sample instead of O(n)
    seen_sample_indices = set()
    metric_lists = []
    for inputs in dataset:
        # read the batch size from the data since the last batch may be smaller
        batch_size = next(iter(inputs.values())).shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays; each processed value is (array, normalize, on_label)
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # indices of this sample, the last one is the label index
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                # label independent tensors go under pair_dir, others under label_dir
                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label's value is already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    # NOTE(review): only outputs written to label_dir overwrite;
                    # presumably pair-level outputs are identical across labels
                    # of the same pair - confirm against save_array
                    overwrite=arr_save_dir == label_dir,
                )

            # guard against repeated datasets before computing the metric
            sample_index_str = "_".join(str(x) for x in indices_i)
            if sample_index_str in seen_sample_indices:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            seen_sample_indices.add(sample_index_str)

            # calculate metric
            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0] if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled
                else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
||
| 148 | |||
| 149 | |||
def build_config(
    config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str
) -> Tuple[Dict, str, str]:
    """
    Create the log directory and load the configuration used for prediction.

    :param config_path: path(s) of configuration files. An empty string or an
        empty list means the config saved next to the checkpoint is used, i.e.
        ``log_folder/config.yaml`` for ``ckpt_path=log_folder/save/ckpt-x``.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where model is stored; ``~`` is expanded.
    :return: - config, configuration dictionary.
             - log_dir, the directory returned by ``build_log_dir`` for
               saving outputs.
             - ckpt_path, the checkpoint path with ``~`` expanded.
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if not config_path:
        # use default config, which should be provided in the log folder two
        # levels above the checkpoint (log_folder/save/ckpt-x). os.path keeps a
        # relative ckpt_path relative instead of resolving to "/config.yaml".
        ckpt_log_folder = os.path.dirname(os.path.dirname(ckpt_path))
        config = config_parser.load_configs(
            os.path.join(ckpt_log_folder, "config.yaml")
        )
    else:
        # use customized config
        logger.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path
||
| 184 | |||
| 185 | |||
def _set_env_variables(gpu: str, gpu_allow_growth: bool, num_workers: int) -> None:
    """
    Configure GPU visibility, GPU memory growth and CPU thread limits.

    :param gpu: value for CUDA_VISIBLE_DEVICES; None leaves the GPU-related
        environment variables untouched.
    :param gpu_allow_growth: if true, TensorFlow allocates GPU memory on demand
        (TF_FORCE_GPU_ALLOW_GROWTH=true) instead of reserving all of it.
    :param num_workers: number of cpu threads to use, <= 0 means not limited.
    """
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
            "true" if gpu_allow_growth else "false"
        )
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)


def _build_strategy(batch_size: int) -> tf.distribute.Strategy:
    """
    Select the distribution strategy for the visible GPUs.

    With several GPUs the network is mirrored on each of them so that a larger
    batch size can be used, see https://www.tensorflow.org/guide/distributed_training

    :param batch_size: total batch size; must be divisible by the number of
        GPUs when more than one is visible.
    :return: a MirroredStrategy for multiple GPUs, else the default strategy.
    :raises ValueError: if batch_size can not be split evenly across the GPUs.
    """
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices ({num_devices})."
            )
        return tf.distribute.MirroredStrategy()
    return tf.distribute.get_strategy()


def predict(
    gpu: str,
    ckpt_path: str,
    split: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to predict some metrics from the saved model and logging results.

    :param gpu: which env gpu to use; None leaves CUDA_VISIBLE_DEVICES unchanged.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
        Paths ending with ``.ckpt`` are loaded with ``model.load_weights``.
    :param split: train / valid / test, to define the split to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: if true, prevent TensorFlow from reserving all
        GPU memory up front (only applied when gpu is given).
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    :raises ValueError: if no data loader can be built for the split, or if
        batch_size can not be split evenly across several GPUs.
    """

    # env vars
    _set_env_variables(
        gpu=gpu, gpu_allow_growth=gpu_allow_growth, num_workers=num_workers
    )

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split=split,
        training=False,
        repeat=False,
    )
    if data_loader is None:
        raise ValueError(f"Failed to build the data loader for split {split}.")

    try:
        # only model, optimizer and metrics need to be defined inside the strategy
        strategy = _build_strategy(batch_size=batch_size)
        with strategy.scope():
            model: tf.keras.Model = REGISTRY.build_model(
                config=dict(
                    name=config["train"]["method"],
                    moving_image_size=data_loader.moving_image_shape,
                    fixed_image_size=data_loader.fixed_image_shape,
                    index_size=data_loader.num_indices,
                    labeled=config["dataset"][split]["labeled"],
                    batch_size=batch_size,
                    config=config["train"],
                )
            )
            optimizer = opt.build_optimizer(
                optimizer_config=config["train"]["optimizer"]
            )
            model.compile(optimizer=optimizer)
            model.plot_model(output_dir=log_dir)

        # load weights
        if ckpt_path.endswith(".ckpt"):
            # for ckpt from tf.keras.callbacks.ModelCheckpoint
            # skip warnings because of optimizers
            # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
            model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
        else:
            # for ckpts from ckpt manager callback; the returned objects are
            # unused, building the callback with ckpt_path is presumably what
            # restores the weights - confirm in deepreg.callback
            build_checkpoint_callback(
                model=model,
                dataset=dataset,
                log_dir=log_dir,
                save_period=config["train"]["save_period"],
                ckpt_path=ckpt_path,
            )

        # predict
        fixed_grid_ref = tf.expand_dims(
            layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
            axis=0,
        )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        predict_on_dataset(
            dataset=dataset,
            fixed_grid_ref=fixed_grid_ref,
            model=model,
            save_dir=os.path.join(log_dir, "test"),
            save_nifti=save_nifti,
            save_png=save_png,
        )
    finally:
        # close the opened files in data loaders, even if prediction fails
        data_loader.close()
||
| 313 | |||
| 314 | |||
def _str_to_bool(value: Union[str, bool]) -> bool:
    """
    Convert a command line value such as "true" / "false" into a bool.

    :param value: string given on the command line; bools are passed through.
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def main(args=None):
    """
    Entry point for predict script.

    :param args: list of command line arguments to parse; None means the
        arguments of the current process (sys.argv[1:]) are parsed by argparse.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for prediction. "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPU 0 and 1. '
        "If omitted, the GPU environment variables are left unchanged.",
        type=str,
        required=False,
    )

    # parsed as a boolean so that e.g. "-gr false" is not a truthy string;
    # "-gr" alone means true
    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory. "
        "Accepts true/false, the flag alone means true.",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of checkpointed model to load",
        default="",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--split",
        help="Define the split of data to be used for prediction: "
        "train or valid or test",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--batch_size", "-b", help="Batch size for predictions", default=1, type=int
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name", "-n", help="Name of the experiment.", default="", type=str
    )

    parser.add_argument("--save_nifti", dest="nifti", action="store_true")
    parser.add_argument("--no_nifti", dest="nifti", action="store_false")
    parser.set_defaults(nifti=True)

    parser.add_argument("--save_png", dest="png", action="store_true")
    parser.add_argument("--no_png", dest="png", action="store_false")
    parser.set_defaults(png=False)

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="*",
        default="",
    )

    args = parser.parse_args(args)

    predict(
        gpu=args.gpu,
        ckpt_path=args.ckpt_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        split=args.split,
        batch_size=args.batch_size,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        # "-c" without values yields an empty list; treat it as "use default"
        config_path=args.config_path or "",
        save_nifti=args.nifti,
        save_png=args.png,
    )
||
| 410 | |||
| 411 | |||
if __name__ == "__main__":
    # run the command line interface when the module is executed as a script
    main()  # pragma: no cover
||
| 414 |