deepreg.predict.predict()   (rating: C)

Complexity
    Conditions: 8

Size
    Total Lines: 127
    Code Lines: 71

Duplication
    Lines: 13
    Ratio: 10.24 %

Importance
    Changes: 0

Metric                             Value
eloc (executable lines of code)    71
dl (duplicated lines)              13
loc (total lines)                  127
rs                                 6.0824
c                                  0
b                                  0
f                                  0
cc (cyclomatic complexity)         8
nop (number of parameters)         11

How to fix: Long Method, Many Parameters

Long Method

Small methods make your code easier to understand, especially when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments inside a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as the starting point for its name.

Commonly applied refactorings include Extract Method and Decompose Conditional; a sketch follows.
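
As a minimal Extract Method sketch against the predict() function listed further down (the helper name setup_environment is ours, not part of deepreg), the "# env vars" block at the top of predict() becomes a named, reusable step:

import os


def setup_environment(gpu: str, gpu_allow_growth: bool, num_workers: int) -> None:
    """Environment-variable handling extracted from predict()."""
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
            "false" if gpu_allow_growth else "true"
        )
    if num_workers > 0:
        # limit CPU usage, as in the block the duplication check flags
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

predict() would then open with a single setup_environment(...) call, shortening its body and giving the step a searchable name.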

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more, or different, data.

There are several approaches to avoid long parameter lists, among them Introduce Parameter Object, Preserve Whole Object, and Replace Parameter with Method Call; a sketch follows.
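
As a sketch of Introduce Parameter Object applied to predict() below (the analyzer reports nop 11), the output-related flags that always travel together can be grouped; the SaveOptions name is hypothetical:

from dataclasses import dataclass


@dataclass
class SaveOptions:
    """Groups the output flags passed through predict() and predict_on_dataset()."""

    save_nifti: bool = True
    save_png: bool = True
    log_dir: str = "logs"


# A caller builds the object once and passes one argument instead of three:
options = SaveOptions(save_nifti=True, save_png=False)

For reference, the full analyzed module is reproduced below.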

# coding=utf-8

"""
Module to perform predictions on data using
command line interface.
"""

import argparse
import os
import shutil
from typing import Dict, List, Tuple, Union

import numpy as np
import tensorflow as tf

import deepreg.config.parser as config_parser
import deepreg.model.layer_util as layer_util
import deepreg.model.optimizer as opt
from deepreg import log
from deepreg.callback import build_checkpoint_callback
from deepreg.registry import REGISTRY
from deepreg.util import (
    build_dataset,
    build_log_dir,
    calculate_metrics,
    save_array,
    save_metric_dict,
)

logger = log.get(__name__)


def build_pair_output_path(indices: list, save_dir: str) -> Tuple[str, str]:
    """
    Create directories for saving the paired data.

    :param indices: indices of the pair, the last one is for label
    :param save_dir: directory of output
    :return: - pair_dir, str, directory for saving the moving/fixed image
             - label_dir, str, directory for saving the remaining outputs
    """
    # cast indices to string and init directory name
    pair_index = "pair_" + "_".join([str(x) for x in indices[:-1]])
    pair_dir = os.path.join(save_dir, pair_index)
    os.makedirs(pair_dir, exist_ok=True)

    if indices[-1] >= 0:
        label_index = f"label_{indices[-1]}"
        label_dir = os.path.join(pair_dir, label_index)
        os.makedirs(label_dir, exist_ok=True)
    else:
        label_dir = pair_dir

    return pair_dir, label_dir


def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Predict results on a dataset with a given model.

    :param dataset: where data is stored
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param save_dir: path of the output directory
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    sample_index_strs = []
    metric_lists = []
    for _, inputs in enumerate(dataset):
        batch_size = inputs[list(inputs.keys())[0]].shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # save label-independent tensors under pair_dir, otherwise under label_dir

            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label's value is already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # calculate metric
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in sample_index_strs:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            sample_index_strs.append(sample_index_str)

            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0] if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled
                else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)


def build_config(
    config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str
) -> Tuple[Dict, str, str]:
    """
    Init the log directory and load the configuration to use for prediction.

    :param config_path: path of configuration files.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where model is stored.
    :return: - config, configuration dictionary.
             - log_dir, path of the directory for saving outputs.
             - ckpt_path, checkpoint path with ~ expanded.
    """
    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder
        config = config_parser.load_configs(
            "/".join(ckpt_path.split("/")[:-2]) + "/config.yaml"
        )
    else:
        # use customized config
        logger.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path


def predict(
    gpu: str,
    ckpt_path: str,
    split: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Predict metrics from the saved model and log the results.

    :param gpu: which env gpu to use.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param split: train / valid / test, to define the split to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: whether to allocate whole GPU memory for prediction.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    """
    # env vars
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
            "false" if gpu_allow_growth else "true"
        )
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split=split,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader.moving_image_shape,
                fixed_image_size=data_loader.fixed_image_shape,
                index_size=data_loader.num_indices,
                labeled=config["dataset"][split]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape), axis=0
    )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()


def main(args=None):
    """
    Entry point for predict script.

    :param args: arguments to parse; by default sys.argv is used.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for prediction. "
        "-g for using GPU remotely, "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPUs 0 and 1.',
        type=str,
        required=False,
    )

    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory",
        default=False,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of checkpointed model to load",
        default="",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--split",
        help="Define the split of data to be used for prediction: "
        "train or valid or test",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--batch_size", "-b", help="Batch size for predictions", default=1, type=int
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name", "-n", help="Name of the experiment.", default="", type=str
    )

    parser.add_argument("--save_nifti", dest="nifti", action="store_true")
    parser.add_argument("--no_nifti", dest="nifti", action="store_false")
    parser.set_defaults(nifti=True)

    parser.add_argument("--save_png", dest="png", action="store_true")
    parser.add_argument("--no_png", dest="png", action="store_false")
    parser.set_defaults(png=False)

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="*",
        default="",
    )

    args = parser.parse_args(args)

    predict(
        gpu=args.gpu,
        ckpt_path=args.ckpt_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        split=args.split,
        batch_size=args.batch_size,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        config_path=args.config_path,
        save_nifti=args.nifti,
        save_png=args.png,
    )


if __name__ == "__main__":
    main()  # pragma: no cover
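
For reference, main() above is what DeepReg wires to its deepreg_predict console script; a typical invocation (checkpoint path and experiment name are illustrative) looks like:

deepreg_predict -g "" -k logs/my_experiment/save/ckpt-4 --split test -b 1 -n my_prediction

Because main() accepts an optional argument list, the same call can also be issued programmatically, e.g. main(["-k", "logs/my_experiment/save/ckpt-4", "--split", "test"]).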