Completed
Push — main ( 6b1f4e...d3edf2 )
by Yunguan
30s queued 14s
created

deepreg.predict   A

Complexity

Total Complexity 23

Size/Duplication

Total Lines 415
Duplicated Lines 3.13 %

Importance

Changes 0
Metric Value
wmc 23
eloc 236
dl 13
loc 415
rs 10
c 0
b 0
f 0

5 Functions

Rating   Name   Duplication   Size   Complexity  
A build_config() 0 34 2
A build_pair_output_path() 0 23 2
C predict_on_dataset() 0 92 11
C predict() 13 125 7
B main() 0 95 1

How to fix   Duplicated Code   

Duplicated Code

Duplicate code is one of the most pungent code smells. A rule that is often used is to re-structure code once it is duplicated in three or more places.

Common duplication problems, and corresponding solutions are:

1
# coding=utf-8
2
3
"""
4
Module to perform predictions on data using
5
command line interface.
6
"""
7
8
import argparse
9
import os
10
import shutil
11
from typing import Dict, List, Tuple, Union
12
13
import numpy as np
14
import tensorflow as tf
15
16
import deepreg.config.parser as config_parser
17
import deepreg.model.layer_util as layer_util
18
import deepreg.model.optimizer as opt
19
from deepreg import log
20
from deepreg.callback import build_checkpoint_callback
21
from deepreg.registry import REGISTRY
22
from deepreg.util import (
23
    build_dataset,
24
    build_log_dir,
25
    calculate_metrics,
26
    save_array,
27
    save_metric_dict,
28
)
29
30
# module-level logger obtained through deepreg's log helper, named after this module
logger = log.get(__name__)
31
32
33
def build_pair_output_path(indices: list, save_dir: str) -> Tuple[str, str]:
    """
    Create (if missing) the output directories for a single sample.

    :param indices: integer indices of the sample; every entry but the last
        identifies the image pair, the last entry is the label index
        (a negative value means there is no label sub-directory).
    :param save_dir: root directory of the outputs.
    :return: - pair_dir, str, directory for label-independent outputs
             - label_dir, str, directory for label-dependent outputs;
               identical to pair_dir when the label index is negative
    """
    pair_name = "pair_" + "_".join(map(str, indices[:-1]))
    pair_dir = os.path.join(save_dir, pair_name)
    os.makedirs(pair_dir, exist_ok=True)

    label_index = indices[-1]
    if label_index < 0:
        # no label: everything is stored directly under the pair directory
        return pair_dir, pair_dir

    label_dir = os.path.join(pair_dir, f"label_{label_index}")
    os.makedirs(label_dir, exist_ok=True)
    return pair_dir, label_dir
56
57
58
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Run the model over every batch of a dataset, save outputs and metrics.

    For each sample, label-independent outputs are written under the pair
    directory and label-dependent outputs under the label directory, both
    created by build_pair_output_path. Per-sample metrics are accumulated and
    written to save_dir once the whole dataset has been processed.

    :param dataset: dataset yielding dicts of input tensors; iterated once.
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model used for prediction; must provide ``postprocess``
        and the ``labeled`` attribute.
    :param model_method: ddf / dvf / affine / conditional (not used here).
    :param save_dir: output directory; removed first if it already exists.
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :raises ValueError: if the same sample indices are seen twice, e.g.
        because the dataset has been repeated.
    """
    # start from a clean output directory
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    # a set gives O(1) duplicate detection instead of rescanning a list
    seen_sample_indices = set()
    metric_lists = []
    for inputs in dataset:
        # read the batch size from the data rather than assuming a fixed value
        batch_size = next(iter(inputs.values())).shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays; each processed value is (array, normalize, on_label)
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # label independent tensors go under pair_dir, the rest under label_dir
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    # affine parameters are stored as a CSV text file
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # guard against a repeated dataset producing duplicated samples
            sample_index_str = "_".join(str(x) for x in indices_i)
            if sample_index_str in seen_sample_indices:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            seen_sample_indices.add(sample_index_str)

            # calculate metric
            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0] if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled
                else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
150
151
152
def build_config(
    config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str
) -> Tuple[Dict, str, str]:
    """
    Load the configuration for prediction and create the output log directory.

    :param config_path: path(s) of configuration files. An empty value
        ("" or an empty list) means using the config.yaml expected in the log
        folder of the checkpoint.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where model is stored, expected to look like
        log_folder/save/ckpt-x; a leading ~ is expanded.
    :return: - config, configuration dictionary.
             - log_dir, path of the directory for saving outputs.
             - ckpt_path, checkpoint path with ~ expanded to the user home.
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config; `not` also covers the [] argparse returns for a bare `-c`
    if not config_path:
        # use default config, which should be provided in the log folder,
        # i.e. two levels above log_folder/save/ckpt-x
        default_config_path = os.path.join(
            os.path.dirname(os.path.dirname(ckpt_path)), "config.yaml"
        )
        config = config_parser.load_configs(default_config_path)
    else:
        # use customized config
        logger.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path
186
187
188
def predict(
    gpu: str,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Restore a saved model, predict on a dataset split and log the results.

    :param gpu: value for CUDA_VISIBLE_DEVICES, e.g. "" for CPU, "0" or "0,1".
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param mode: train / valid / test, to define which split of dataset to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: if true, TensorFlow allocates GPU memory on demand
        instead of reserving all available GPU memory up front.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    :raises ValueError: if several GPUs are visible and batch_size is not
        divisible by their number.
    """

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # "true" makes TensorFlow grow GPU memory usage on demand, matching the flag
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    # only a positive number of workers imposes a limit; <= 0 means unlimited
    if num_workers > 0:
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # close the opened files in data loaders even if prediction fails
    try:
        # use strategy to support multiple GPUs
        # the network is mirrored in each GPU so that we can use larger batch size
        # https://www.tensorflow.org/guide/distributed_training
        # only model, optimizer and metrics need to be defined inside the strategy
        num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
        if num_devices > 1:  # pragma: no cover
            # validate before building the mirrored strategy
            if batch_size % num_devices != 0:
                raise ValueError(
                    f"batch size {batch_size} can not be divided evenly "
                    f"by the number of devices ({num_devices})."
                )
            strategy = tf.distribute.MirroredStrategy()
        else:
            strategy = tf.distribute.get_strategy()
        with strategy.scope():
            model: tf.keras.Model = REGISTRY.build_model(
                config=dict(
                    name=config["train"]["method"],
                    moving_image_size=data_loader.moving_image_shape,
                    fixed_image_size=data_loader.fixed_image_shape,
                    index_size=data_loader.num_indices,
                    labeled=config["dataset"]["labeled"],
                    batch_size=batch_size,
                    config=config["train"],
                )
            )
            optimizer = opt.build_optimizer(
                optimizer_config=config["train"]["optimizer"]
            )
            model.compile(optimizer=optimizer)
            model.plot_model(output_dir=log_dir)

        # load weights
        if ckpt_path.endswith(".ckpt"):
            # for ckpt from tf.keras.callbacks.ModelCheckpoint
            # skip warnings because of optimizers
            # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
            model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
        else:
            # for ckpts from ckpt manager callback
            build_checkpoint_callback(
                model=model,
                dataset=dataset,
                log_dir=log_dir,
                save_period=config["train"]["save_period"],
                ckpt_path=ckpt_path,
            )

        # predict
        fixed_grid_ref = tf.expand_dims(
            layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
            axis=0,
        )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        predict_on_dataset(
            dataset=dataset,
            fixed_grid_ref=fixed_grid_ref,
            model=model,
            model_method=config["train"]["method"],
            save_dir=os.path.join(log_dir, "test"),
            save_nifti=save_nifti,
            save_png=save_png,
        )
    finally:
        data_loader.close()
313
314
315
def _str_to_bool(value: str) -> bool:
    """
    Parse a command-line string into a boolean.

    :param value: case-insensitive string such as "true"/"false", "yes"/"no",
        "1"/"0" or "on"/"off"; an empty string is treated as false.
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def main(args=None):
    """
    Entry point for predict script.

    :param args: list of command-line arguments to parse; if None, argparse
        reads them from sys.argv.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for prediction. "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=True,
    )

    # accepts a bare `-gr` (true) as well as an explicit `-gr true/false`,
    # so "false" is no longer parsed as a truthy non-empty string
    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of checkpointed model to load",
        default="",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--mode",
        "-m",
        help="Define the split of data to be used for prediction: "
        "train or valid or test",
        type=str,
        default="test",
        required=True,
    )

    parser.add_argument(
        "--batch_size", "-b", help="Batch size for predictions", default=1, type=int
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name", "-n", help="Name of the experiment.", default="", type=str
    )

    parser.add_argument("--save_nifti", dest="nifti", action="store_true")
    parser.add_argument("--no_nifti", dest="nifti", action="store_false")
    parser.set_defaults(nifti=True)

    parser.add_argument("--save_png", dest="png", action="store_true")
    parser.add_argument("--no_png", dest="png", action="store_false")
    parser.set_defaults(png=False)

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="*",
        default="",
    )

    args = parser.parse_args(args)

    predict(
        gpu=args.gpu,
        ckpt_path=args.ckpt_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        mode=args.mode,
        batch_size=args.batch_size,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        config_path=args.config_path,
        save_nifti=args.nifti,
        save_png=args.png,
    )
411
412
413
# run the command-line entry point when this module is executed as a script
if __name__ == "__main__":
    main()  # pragma: no cover
415