Passed
Pull Request — main (#724)
by Yunguan
01:28
created

deepreg.predict.predict()   C

Complexity

Conditions 7

Size

Total Lines 124
Code Lines 70

Duplication

Lines 13
Ratio 10.48 %

Importance

Changes 0
Metric Value
eloc 70
dl 13
loc 124
rs 6.5818
c 0
b 0
f 0
cc 7
nop 11

How to fix   Long Method    Many Parameters   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

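Commonly applied refactorings include Extract Method and, when many parameters or temporary variables are present, Replace Temp with Query, Introduce Parameter Object, and Preserve Whole Object.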

Many Parameters

Methods with many parameters are not only hard to understand, but their parameters also often become inconsistent when you need more, or different data.

There are several approaches to avoid long parameter lists, such as Replace Parameter with Method, Introduce Parameter Object, and Preserve Whole Object.

1
# coding=utf-8
2
3
"""
4
Module to perform predictions on data using
5
command line interface.
6
"""
7
8
import argparse
9
import os
10
import shutil
11
from typing import Dict, List, Tuple, Union
12
13
import numpy as np
14
import tensorflow as tf
15
16
import deepreg.config.parser as config_parser
17
import deepreg.model.layer_util as layer_util
18
import deepreg.model.optimizer as opt
19
from deepreg import log
20
from deepreg.callback import build_checkpoint_callback
21
from deepreg.registry import REGISTRY
22
from deepreg.util import (
23
    build_dataset,
24
    build_log_dir,
25
    calculate_metrics,
26
    save_array,
27
    save_metric_dict,
28
)
29
30
# module-level logger, used by build_config and predict below
logger = log.get(__name__)
31
32
33
def build_pair_output_path(indices: list, save_dir: str) -> Tuple[str, str]:
    """
    Prepare the output directories of one image pair.

    The pair directory is named after all indices except the last one.
    If the last index, i.e. the label index, is non-negative,
    a label sub-directory is additionally created inside the pair directory.

    :param indices: indices of the pair, the last one is for label
    :param save_dir: directory of output
    :return: - pair_dir, str, directory for saving the moving/fixed image
             - label_dir, str, directory for saving the rest outputs,
               it equals pair_dir when the label index is negative
    """
    # directory shared by all labels of the same pair
    pair_name = "pair_" + "_".join(map(str, indices[:-1]))
    pair_dir = os.path.join(save_dir, pair_name)
    os.makedirs(pair_dir, exist_ok=True)

    label_idx = indices[-1]
    if label_idx >= 0:
        # labeled sample, label dependent outputs get their own sub-directory
        label_dir = os.path.join(pair_dir, f"label_{label_idx}")
        os.makedirs(label_dir, exist_ok=True)
        return pair_dir, label_dir

    # unlabeled sample, everything is stored directly under the pair directory
    return pair_dir, pair_dir
56
57
58
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Function to predict results from a dataset from some model.

    For each batch, the model outputs are post-processed, the resulting arrays
    are saved per sample and metrics are calculated per sample.
    All metrics are saved under save_dir at the end.

    :param dataset: where data is stored
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: ddf / dvf / affine / conditional,
        currently not used inside this function
    :param save_dir: path to store dir, it is removed first if it exists
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :raises ValueError: if the same sample indices are met twice
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    # a set gives constant time duplicate detection,
    # a list would make the whole loop quadratic in the number of samples
    seen_sample_strs = set()
    metric_lists = []
    for inputs in dataset:
        # all inputs share the batch dimension, read it from the first one
        batch_size = next(iter(inputs.values())).shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays
        # each value of processed is a tuple (array, normalize, on_label)
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # save label independent tensors under pair_dir, otherwise under label_dir

            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    # theta is saved as a text file and not as an image
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label's value is already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in seen_sample_strs:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            seen_sample_strs.add(sample_index_str)

            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0] if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled
                else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
150
151
152
def build_config(
    config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str
) -> Tuple[Dict, str, str]:
    """
    Function to create the log directory and to load the configuration.

    If no config path is given, the config is loaded from config.yaml
    located two levels above the checkpoint, i.e. for a checkpoint
    log_folder/save/ckpt-x the file log_folder/config.yaml is used.

    :param config_path: path of configuration files,
        an empty string or an empty list means using the default config.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where model is stored.
    :return: - config, configuration dictionary.
             - log_dir, path of the directory for saving outputs.
             - ckpt_path, checkpoint path where ~ is replaced by the user home.
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if not config_path:
        # use default config, which should be provided in the log folder
        # os.path is used instead of splitting on "/" so that a short relative
        # checkpoint path does not resolve to the file system root
        # and so that the path separator of the platform is respected
        default_config_path = os.path.join(
            os.path.dirname(os.path.dirname(ckpt_path)), "config.yaml"
        )
        config = config_parser.load_configs(default_config_path)
    else:
        # use customized config
        logger.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path
186
187
188
def predict(
    gpu: str,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to predict some metrics from the saved model and logging results.

    The function sets the environment variables, loads the config,
    builds the dataset and the model, restores the weights from the checkpoint
    and saves the predictions under log_dir/test.

    :param gpu: which env gpu to use.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param mode: train / valid / test, to define which split of dataset to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: if true, TF_FORCE_GPU_ALLOW_GROWTH is set to true,
        otherwise it is set to false.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    :raises ValueError: if multiple GPUs are used and
        the batch size can not be divided evenly by the number of devices.
    """

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # the value has to follow the flag, "true" enables the memory growth
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    # only a positive number of workers is a limitation,
    # zero or negative values mean the CPU usage is not limited
    if num_workers > 0:
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # the data loader is closed in the finally clause
    # so that it is released even if the prediction fails
    try:
        # use strategy to support multiple GPUs
        # the network is mirrored in each GPU so that we can use larger batch size
        # https://www.tensorflow.org/guide/distributed_training
        # only model, optimizer and metrics need to be defined inside the strategy
        num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
        if num_devices > 1:  # pragma: no cover
            strategy = tf.distribute.MirroredStrategy()
            if batch_size % num_devices != 0:
                raise ValueError(
                    f"batch size {batch_size} can not be divided evenly "
                    "by the number of devices."
                )
        else:
            strategy = tf.distribute.get_strategy()
        with strategy.scope():
            model: tf.keras.Model = REGISTRY.build_model(
                config=dict(
                    name=config["train"]["method"],
                    moving_image_size=data_loader.moving_image_shape,
                    fixed_image_size=data_loader.fixed_image_shape,
                    index_size=data_loader.num_indices,
                    labeled=config["dataset"]["labeled"],
                    batch_size=batch_size,
                    config=config["train"],
                )
            )
            optimizer = opt.build_optimizer(
                optimizer_config=config["train"]["optimizer"]
            )
            model.compile(optimizer=optimizer)
            model.plot_model(output_dir=log_dir)

        # load weights
        if ckpt_path.endswith(".ckpt"):
            # for ckpt from tf.keras.callbacks.ModelCheckpoint
            # skip warnings because of optimizers
            # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
            model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
        else:
            # for ckpts from ckpt manager callback
            _, _ = build_checkpoint_callback(
                model=model,
                dataset=dataset,
                log_dir=log_dir,
                save_period=config["train"]["save_period"],
                ckpt_path=ckpt_path,
            )

        # predict
        fixed_grid_ref = tf.expand_dims(
            layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
            axis=0,
        )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        predict_on_dataset(
            dataset=dataset,
            fixed_grid_ref=fixed_grid_ref,
            model=model,
            model_method=config["train"]["method"],
            save_dir=os.path.join(log_dir, "test"),
            save_nifti=save_nifti,
            save_png=save_png,
        )
    finally:
        # close the opened files in data loaders
        data_loader.close()
312
313
314
def _str_to_bool(value: str) -> bool:
    """
    Convert a command line string into a boolean.

    :param value: string given in the command line, case insensitive.
    :return: the corresponding boolean value.
    :raises argparse.ArgumentTypeError: if the string is not a known boolean value.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "on", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def main(args=None):
    """
    Entry point for predict script.

    Parse the command line arguments and call predict with them.

    :param args: list of command line arguments,
        if None the arguments are read from sys.argv by argparse.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for prediction. "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=True,
    )

    # without a type, any given value including "False" is a non-empty string
    # and therefore considered as true, so the value is parsed explicitly
    # the flag can also be given without value, which means true
    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    # required arguments do not take a default value
    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of checkpointed model to load",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--mode",
        "-m",
        help="Define the split of data to be used for prediction, "
        "train or valid or test",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--batch_size", "-b", help="Batch size for predictions", default=1, type=int
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name", "-n", help="Name of the experiment.", default="", type=str
    )

    parser.add_argument("--save_nifti", dest="nifti", action="store_true")
    parser.add_argument("--no_nifti", dest="nifti", action="store_false")
    parser.set_defaults(nifti=True)

    parser.add_argument("--save_png", dest="png", action="store_true")
    parser.add_argument("--no_png", dest="png", action="store_false")
    parser.set_defaults(png=False)

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="*",
        default="",
    )

    args = parser.parse_args(args)

    predict(
        gpu=args.gpu,
        ckpt_path=args.ckpt_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        mode=args.mode,
        batch_size=args.batch_size,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        # "-c" without any path gives an empty list,
        # which is mapped to "" meaning the default config
        config_path=args.config_path or "",
        save_nifti=args.nifti,
        save_png=args.png,
    )
410
411
412
if __name__ == "__main__":
    # only run the command line interface when executed as a script
    main()  # pragma: no cover
414