Completed
Push — main ( 6b1f4e...d3edf2 ) by Yunguan
30s (queued 14s)

deepreg.predict.predict()   (rating: C)

Complexity
    Conditions: 7

Size
    Total Lines: 125
    Code Lines: 70

Duplication
    Lines: 13
    Ratio: 10.4 %

Importance
    Changes: 0

Metric    Value
eloc      70        (code lines)
dl        13        (duplicated lines)
loc       125       (total lines)
rs        6.5818
c         0
b         0
f         0
cc        7         (cyclomatic complexity)
nop       11        (number of parameters)

How to fix

Long Method

Small methods make your code easier to understand, especially when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments inside a method's body, that is usually a good sign that the commented part should be extracted into a method of its own, with the comment as a starting point for its name.

Commonly applied refactorings include:

- Extract Method
- Decompose Conditional
- Replace Temp with Query
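As a hedged illustration of Extract Method (not part of the DeepReg codebase; the name save_sample_arrays and its signature are assumptions), the commented per-sample saving block inside predict_on_dataset in the listing below could become a named method:

# A minimal sketch, assuming DeepReg is installed; save_sample_arrays is a
# hypothetical name for the extracted method.
import os
from typing import Dict, Tuple

import numpy as np

from deepreg.util import save_array  # same helper the module already imports


def save_sample_arrays(
    processed: Dict[str, Tuple[np.ndarray, bool, bool]],
    sample_index: int,
    pair_dir: str,
    label_dir: str,
    save_nifti: bool,
    save_png: bool,
) -> None:
    """Save every processed array of one sample under its output directory."""
    for name, (arr, normalize, on_label) in processed.items():
        if name == "theta":
            # affine parameters are stored as plain text, not as images
            np.savetxt(
                fname=os.path.join(pair_dir, "affine.txt"),
                X=arr[sample_index, :, :],
                delimiter=",",
            )
            continue
        # label-dependent arrays go under label_dir, the rest under pair_dir
        arr_save_dir = label_dir if on_label else pair_dir
        save_array(
            save_dir=arr_save_dir,
            arr=arr[sample_index, :, :, :],
            name=name,
            normalize=normalize,  # label values are already in [0, 1]
            save_nifti=save_nifti,
            save_png=save_png,
            overwrite=arr_save_dir == label_dir,
        )

The comment that used to label the block then becomes the method name and docstring, and predict_on_dataset shrinks accordingly.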

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more, or different, data.

There are several approaches to avoiding long parameter lists:

- Introduce Parameter Object
- Preserve Whole Object
- Replace Parameter with Method Call
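For predict() in the listing below, which takes eleven parameters (nop = 11 in the metrics above), Introduce Parameter Object might look like the following sketch; PredictOptions is a hypothetical name, not part of DeepReg's API:

# A minimal sketch, assuming the body of predict() stays unchanged except for
# reading its inputs from the options object; PredictOptions is hypothetical.
from dataclasses import dataclass
from typing import List, Union


@dataclass
class PredictOptions:
    """Bundles the run-time options of predict() into a single object."""

    gpu: str
    ckpt_path: str
    mode: str
    batch_size: int
    exp_name: str
    config_path: Union[str, List[str]]
    num_workers: int = 1
    gpu_allow_growth: bool = True
    save_nifti: bool = True
    save_png: bool = True
    log_dir: str = "logs"


def predict(options: PredictOptions) -> None:
    # ... body unchanged, reading options.gpu, options.ckpt_path, and so on
    ...

One payoff is that main() would build a single PredictOptions from the parsed arguments, so adding an option touches one class instead of every call site.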

Source: deepreg/predict.py

# coding=utf-8

"""
Module to perform predictions on data using
command line interface.
"""

import argparse
import os
import shutil
from typing import Dict, List, Tuple, Union

import numpy as np
import tensorflow as tf

import deepreg.config.parser as config_parser
import deepreg.model.layer_util as layer_util
import deepreg.model.optimizer as opt
from deepreg import log
from deepreg.callback import build_checkpoint_callback
from deepreg.registry import REGISTRY
from deepreg.util import (
    build_dataset,
    build_log_dir,
    calculate_metrics,
    save_array,
    save_metric_dict,
)

logger = log.get(__name__)


def build_pair_output_path(indices: list, save_dir: str) -> Tuple[str, str]:
    """
    Create directory for saving the paired data

    :param indices: indices of the pair, the last one is for label
    :param save_dir: directory of output
    :return: - save_dir, str, directory for saving the moving/fixed image
             - label_dir, str, directory for saving the rest outputs
    """
    # cast indices to string and init directory name
    pair_index = "pair_" + "_".join([str(x) for x in indices[:-1]])
    pair_dir = os.path.join(save_dir, pair_index)
    os.makedirs(pair_dir, exist_ok=True)

    if indices[-1] >= 0:
        label_index = f"label_{indices[-1]}"
        label_dir = os.path.join(pair_dir, label_index)
        os.makedirs(label_dir, exist_ok=True)
    else:
        label_dir = pair_dir

    return pair_dir, label_dir


def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Function to predict results from a dataset from some model

    :param dataset: where data is stored
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: ddf / dvf / affine / conditional
    :param save_dir: path to store dir
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    sample_index_strs = []
    metric_lists = []
    for _, inputs in enumerate(dataset):
        batch_size = inputs[list(inputs.keys())[0]].shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # save label independent tensors under pair_dir, otherwise under label_dir

            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label's value is already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in sample_index_strs:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            sample_index_strs.append(sample_index_str)

            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0] if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled
                else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)


def build_config(
    config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str
) -> Tuple[Dict, str, str]:
    """
    Function to create new directory to log directory to store results.

    :param config_path: path of configuration files.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where model is stored.
    :return: - config, configuration dictionary.
             - exp_name, path of the directory for saving outputs.
    """
    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder
        config = config_parser.load_configs(
            "/".join(ckpt_path.split("/")[:-2]) + "/config.yaml"
        )
    else:
        # use customized config
        logger.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path


def predict(
    gpu: str,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to predict some metrics from the saved model and logging results.

    :param gpu: which env gpu to use.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param mode: train / valid / test, to define which split of dataset to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: whether to allocate whole GPU memory for training.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    """
    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "false" if gpu_allow_growth else "true"
    # [review] View Code Duplication: this block is flagged as duplicated in the project
    if num_workers <= 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader.moving_image_shape,
                fixed_image_size=data_loader.fixed_image_shape,
                index_size=data_loader.num_indices,
                labeled=config["dataset"]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape), axis=0
    )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        model_method=config["train"]["method"],
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()


def main(args=None):
    """
    Entry point for predict script.

    :param args:
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for training."
        '-g "" for using CPU'
        '-g "0" for using GPU 0'
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=True,
    )

    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory",
        default=False,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of checkpointed model to load",
        default="",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--mode",
        "-m",
        help="Define the split of data to be used for prediction."
        "train or valid or test",
        type=str,
        default="test",
        required=True,
    )

    parser.add_argument(
        "--batch_size", "-b", help="Batch size for predictions", default=1, type=int
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name", "-n", help="Name of the experiment.", default="", type=str
    )

    parser.add_argument("--save_nifti", dest="nifti", action="store_true")
    parser.add_argument("--no_nifti", dest="nifti", action="store_false")
    parser.set_defaults(nifti=True)

    parser.add_argument("--save_png", dest="png", action="store_true")
    parser.add_argument("--no_png", dest="png", action="store_false")
    parser.set_defaults(png=False)

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="*",
        default="",
    )

    args = parser.parse_args(args)

    predict(
        gpu=args.gpu,
        ckpt_path=args.ckpt_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        mode=args.mode,
        batch_size=args.batch_size,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        config_path=args.config_path,
        save_nifti=args.nifti,
        save_png=args.png,
    )


if __name__ == "__main__":
    main()  # pragma: no cover
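For reference, main() accepts an argument list, so the entry point can also be exercised directly from Python; the flags below are the ones registered by main() above, while the checkpoint path and experiment name are placeholders:

# A minimal usage sketch; the checkpoint path is a placeholder.
from deepreg.predict import main

main(
    [
        "--gpu", "",  # empty string selects CPU, as described in the --gpu help
        "--ckpt_path", "logs/demo/save/ckpt-2",  # placeholder path
        "--mode", "test",
        "--batch_size", "1",
        "--exp_name", "demo_prediction",
    ]
)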