# coding=utf-8

"""
Module to perform predictions on data using a command line interface.
"""

import argparse
import os
import shutil
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf

import deepreg.config.parser as config_parser
import deepreg.model.layer_util as layer_util
import deepreg.model.optimizer as opt
from deepreg import log
from deepreg.callback import build_checkpoint_callback
from deepreg.registry import REGISTRY
from deepreg.util import (
    build_dataset,
    build_log_dir,
    calculate_metrics,
    save_array,
    save_metric_dict,
)

logger = log.get(__name__)


def build_pair_output_path(indices: list, save_dir: str) -> Tuple[str, str]:
    """
    Create directories for saving the paired data.

    :param indices: indices of the pair, the last one is for the label
    :param save_dir: directory of output
    :return: - pair_dir, str, directory for saving the moving/fixed image
             - label_dir, str, directory for saving the remaining outputs
    """

    # cast indices to string and init directory name
    pair_index = "pair_" + "_".join([str(x) for x in indices[:-1]])
    pair_dir = os.path.join(save_dir, pair_index)
    os.makedirs(pair_dir, exist_ok=True)

    if indices[-1] >= 0:
        label_index = f"label_{indices[-1]}"
        label_dir = os.path.join(pair_dir, label_index)
        os.makedirs(label_dir, exist_ok=True)
    else:
        label_dir = pair_dir

    return pair_dir, label_dir
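

# For illustration (hypothetical values): build_pair_output_path([0, 2, 1], "out")
# returns ("out/pair_0_2", "out/pair_0_2/label_1") on POSIX paths; with a
# negative label index, e.g. [0, 2, -1], both returned paths are "out/pair_0_2".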


def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Predict results on a dataset using a given model.

    :param dataset: where data is stored
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param save_dir: path of the directory to save outputs
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    sample_index_strs = []
    metric_lists = []
    for inputs in dataset:
        batch_size = inputs[list(inputs.keys())[0]].shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }
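        # each value of processed is a tuple (array, normalize, on_label):
        # the array to save, whether to normalize it before saving, and
        # whether it is label dependent (and thus saved under label_dir below)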

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # save label-independent tensors under pair_dir, the rest under label_dir

            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )

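            # "theta" holds the affine transformation parameters, which are
            # 2-D per sample, so they are saved as text instead of an image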
            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label's value is already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in sample_index_strs:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            sample_index_strs.append(sample_index_str)

            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0] if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled
                else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)


def build_config(
    config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str
) -> Tuple[Dict, str, str]:
    """
    Initialize the log directory and load the configuration for prediction.

    :param config_path: path(s) of configuration files.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where model is stored.
    :return: - config, configuration dictionary.
             - log_dir, path of the directory for saving outputs.
             - ckpt_path, checkpoint path with the user home expanded.
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder
        config = config_parser.load_configs(
            "/".join(ckpt_path.split("/")[:-2]) + "/config.yaml"
        )
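        # e.g. (hypothetical path) for ckpt_path "logs/demo/save/ckpt-6",
        # the default config resolves to "logs/demo/config.yaml"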
    else:
        # use customized config
        logger.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path


def predict(
    gpu: Optional[str],
    ckpt_path: str,
    split: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Run prediction on a data split with a saved model and log the results.

    :param gpu: GPU index(es) to use, e.g. "0" or "0,1"; None means no preference.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param split: train / valid / test, to define the split to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: whether to allow GPU memory growth
        instead of reserving all available memory.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    """

    # env vars
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
            "true" if gpu_allow_growth else "false"
        )
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split=split,
        training=False,
        repeat=False,
    )
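    # build_dataset may return no data loader, e.g. if the requested split is
    # not configured, so fail early rather than during prediction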
    assert data_loader is not None

    # use strategy to support multiple GPUs
    # the network is mirrored on each GPU so that we can use a larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} cannot be divided evenly "
                "by the number of devices."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader.moving_image_shape,
                fixed_image_size=data_loader.fixed_image_shape,
                index_size=data_loader.num_indices,
                labeled=config["dataset"][split]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback
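        # building the callback restores the model weights from ckpt_path;
        # the returned callback objects themselves are not needed here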
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape), axis=0
    )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()


def main(args=None):
    """
    Entry point for predict script.

    :param args: arguments to parse; defaults to sys.argv when None.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for prediction. "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPUs 0 and 1.',
        type=str,
        required=False,
    )

    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory.",
        default=False,
        action="store_true",
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of checkpointed model to load",
        default="",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--split",
        help="Define the split of data to be used for prediction: "
        "train or valid or test",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--batch_size", "-b", help="Batch size for predictions", default=1, type=int
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name", "-n", help="Name of the experiment.", default="", type=str
    )

    parser.add_argument("--save_nifti", dest="nifti", action="store_true")
    parser.add_argument("--no_nifti", dest="nifti", action="store_false")
    parser.set_defaults(nifti=True)

    parser.add_argument("--save_png", dest="png", action="store_true")
    parser.add_argument("--no_png", dest="png", action="store_false")
    parser.set_defaults(png=False)

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="*",
        default="",
    )

    args = parser.parse_args(args)

    predict(
        gpu=args.gpu,
        ckpt_path=args.ckpt_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        split=args.split,
        batch_size=args.batch_size,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        config_path=args.config_path,
        save_nifti=args.nifti,
        save_png=args.png,
    )

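
# Example (hypothetical checkpoint path): predict on the test split with
#   main(args=["--ckpt_path", "logs/demo/save/ckpt-6", "--split", "test"])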
if __name__ == "__main__":
    main()  # pragma: no cover