deepreg.predict.main()   B
last analyzed

Complexity

Conditions 1

Size

Total Lines 94
Code Lines 62

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 62
dl 0
loc 94
rs 8.2436
c 0
b 0
f 0
cc 1
nop 1

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: move a coherent part of the method body into a new, well-named method.

1
# coding=utf-8
2
3
"""
4
Module to perform predictions on data using
5
command line interface.
6
"""
7
8
import argparse
9
import os
10
import shutil
11
from typing import Dict, List, Tuple, Union
12
13
import numpy as np
14
import tensorflow as tf
15
16
import deepreg.config.parser as config_parser
17
import deepreg.model.layer_util as layer_util
18
import deepreg.model.optimizer as opt
19
from deepreg import log
20
from deepreg.callback import build_checkpoint_callback
21
from deepreg.registry import REGISTRY
22
from deepreg.util import (
23
    build_dataset,
24
    build_log_dir,
25
    calculate_metrics,
26
    save_array,
27
    save_metric_dict,
28
)
29
30
# module-level logger obtained from deepreg's log helper; used by the
# functions below for warnings (custom config) and info (CPU limits)
logger = log.get(__name__)
31
32
33
def build_pair_output_path(indices: list, save_dir: str) -> Tuple[str, str]:
    """
    Create the directories for saving the outputs of one sample pair.

    The pair directory is ``save_dir/pair_<i0>_<i1>...`` built from all
    indices but the last one. If the last index (the label index) is
    non-negative, a ``label_<idx>`` sub-directory is created inside it;
    otherwise label-dependent outputs share the pair directory.

    :param indices: indices of the pair, the last one is for label
    :param save_dir: directory of output
    :return: - pair_dir, str, directory for saving the moving/fixed image
             - label_dir, str, directory for saving the rest outputs
               (equal to pair_dir when the label index is negative)
    """
    # cast indices to string and init directory name
    pair_index = "pair_" + "_".join(map(str, indices[:-1]))
    pair_dir = os.path.join(save_dir, pair_index)
    os.makedirs(pair_dir, exist_ok=True)

    label_idx = indices[-1]
    if label_idx < 0:
        # negative label index: no dedicated label directory
        return pair_dir, pair_dir

    label_dir = os.path.join(pair_dir, f"label_{label_idx}")
    os.makedirs(label_dir, exist_ok=True)
    return pair_dir, label_dir
56
57
58
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Predict on every batch of a dataset, save per-sample outputs and metrics.

    Any existing ``save_dir`` is removed first. For each sample, outputs are
    saved under the directories returned by ``build_pair_output_path``; the
    per-sample metrics are collected and written once at the end.

    :param dataset: where data is stored; each element is a dict-like mapping
        of batched inputs
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction, must provide
        ``postprocess`` and ``labeled``
    :param save_dir: path to store dir
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :raises ValueError: if the same sample indices are seen twice, e.g.
        because the dataset has been repeated
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    seen_sample_indices = set()  # set gives O(1) duplicate detection
    metric_lists = []
    for inputs in dataset:
        # batch size is read from the leading dimension of the first input
        batch_size = next(iter(inputs.values())).shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays; each value is (array, normalize, on_label)
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        for sample_index in range(batch_size):
            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )
            _save_sample_outputs(
                processed=processed,
                sample_index=sample_index,
                pair_dir=pair_dir,
                label_dir=label_dir,
                save_nifti=save_nifti,
                save_png=save_png,
            )

            sample_index_str = "_".join(str(x) for x in indices_i)
            if sample_index_str in seen_sample_indices:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            seen_sample_indices.add(sample_index_str)

            metric_lists.append(
                _compute_sample_metric(
                    processed=processed,
                    labeled=model.labeled,
                    fixed_grid_ref=fixed_grid_ref,
                    sample_index=sample_index,
                    indices_i=indices_i,
                )
            )

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)


def _save_sample_outputs(
    processed: dict,
    sample_index: int,
    pair_dir: str,
    label_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Save every processed output of one sample of a batch.

    :param processed: mapping from output name to (batched array, normalize,
        on_label), as built in ``predict_on_dataset``
    :param sample_index: index of the sample inside the batch
    :param pair_dir: directory for label independent outputs
    :param label_dir: directory for label dependent outputs
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    for name, (arr, normalize, on_label) in processed.items():
        if name == "theta":
            # theta is saved as a comma separated text file, not as an image
            np.savetxt(
                fname=os.path.join(pair_dir, "affine.txt"),
                X=arr[sample_index, :, :],
                delimiter=",",
            )
            continue

        # save label independent tensors under pair_dir, otherwise under label_dir
        arr_save_dir = label_dir if on_label else pair_dir
        save_array(
            save_dir=arr_save_dir,
            arr=arr[sample_index, :, :, :],
            name=name,
            normalize=normalize,  # label's value is already in [0, 1]
            save_nifti=save_nifti,
            save_png=save_png,
            overwrite=arr_save_dir == label_dir,
        )


def _compute_sample_metric(
    processed: dict,
    labeled: bool,
    fixed_grid_ref,
    sample_index: int,
    indices_i: list,
):
    """
    Calculate the metrics of one sample and tag them with its indices.

    :param processed: mapping from output name to (batched array, ...)
    :param labeled: whether label outputs are available
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param sample_index: index of the sample inside the batch
    :param indices_i: indices of the sample, the last one is for label
    :return: metric mapping returned by ``calculate_metrics``, extended with
        ``pair_index`` and ``label_index``
    """
    metric = calculate_metrics(
        fixed_image=processed["fixed_image"][0],
        fixed_label=processed["fixed_label"][0] if labeled else None,
        pred_fixed_image=processed["pred_fixed_image"][0],
        pred_fixed_label=processed["pred_fixed_label"][0] if labeled else None,
        fixed_grid_ref=fixed_grid_ref,
        sample_index=sample_index,
    )
    metric["pair_index"] = indices_i[:-1]
    metric["label_index"] = indices_i[-1]
    return metric
148
149
150
def build_config(
    config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str
) -> Tuple[Dict, str, str]:
    """
    Create the log directory and load the configuration used for prediction.

    :param config_path: path(s) of configuration files. An empty value
        ("" or an empty list, as produced by ``-c`` without arguments) or
        None means the config.yaml saved in the checkpoint's log folder is used.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where model is stored, like log_folder/save/ckpt-x.
    :return: - config, configuration dictionary.
             - log_dir, path of the directory for saving outputs.
             - ckpt_path, checkpoint path with ``~`` expanded.
    """
    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if not config_path:
        # use default config, which should be provided in the log folder,
        # i.e. two levels above log_folder/save/ckpt-x; os.path keeps a
        # relative path relative (no spurious leading "/")
        default_config_path = os.path.join(
            os.path.dirname(os.path.dirname(ckpt_path)), "config.yaml"
        )
        config = config_parser.load_configs(default_config_path)
    else:
        # use customized config
        logger.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path
184
185
186
def predict(
    gpu: str,
    ckpt_path: str,
    split: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Predict on a data split with a saved model and save outputs and metrics.

    :param gpu: which env gpu to use, written to CUDA_VISIBLE_DEVICES;
        None leaves the GPU environment variables untouched.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param split: train / valid / test, to define the split to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: if true, TensorFlow allocates GPU memory on demand
        instead of reserving the whole GPU; only applied when gpu is not None.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    :raises ValueError: if no data loader is built for the split, or if several
        GPUs are visible and batch_size is not divisible by their number.
    """
    # env vars, must be set before TensorFlow initialises the devices
    _set_env_vars(gpu=gpu, gpu_allow_growth=gpu_allow_growth, num_workers=num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split=split,
        training=False,
        repeat=False,
    )
    if data_loader is None:
        raise ValueError(f"No data loader could be built for split {split}.")

    # only model, optimizer and metrics need to be defined inside the strategy
    strategy = _build_strategy(batch_size=batch_size)
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader.moving_image_shape,
                fixed_image_size=data_loader.fixed_image_shape,
                index_size=data_loader.num_indices,
                labeled=config["dataset"][split]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback; the call restores ckpt_path
        # into the model, its returned objects are not needed here
        build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape), axis=0
    )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()


def _set_env_vars(gpu: str, gpu_allow_growth: bool, num_workers: int):
    """
    Configure GPU visibility, GPU memory growth and CPU thread limits.

    NOTE(review): the code-quality report flags the CPU-limiting part as
    duplicated elsewhere in the project; consider sharing one helper.

    :param gpu: value for CUDA_VISIBLE_DEVICES; None skips the GPU settings.
    :param gpu_allow_growth: if true, let TensorFlow grow GPU memory on demand.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    """
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu
        # "true" enables on-demand allocation, i.e. does not reserve the whole GPU
        os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
            "true" if gpu_allow_growth else "false"
        )
    if num_workers > 0:  # pragma: no cover
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)


def _build_strategy(batch_size: int):
    """
    Return a MirroredStrategy when several GPUs are visible, else the default.

    The network is mirrored in each GPU so that larger batch sizes can be used,
    see https://www.tensorflow.org/guide/distributed_training

    :param batch_size: global batch size, split evenly across the GPUs.
    :raises ValueError: if several GPUs are visible and batch_size is not
        divisible by their number.
    """
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices == 1:
        return tf.distribute.get_strategy()
    # validate before creating the mirrored strategy
    if batch_size % num_devices != 0:  # pragma: no cover
        raise ValueError(
            f"batch size {batch_size} can not be divided evenly "
            f"by the number of devices."
        )
    return tf.distribute.MirroredStrategy()  # pragma: no cover
313
314
315
def main(args=None):
    """
    Entry point for predict script.

    Parses the command line arguments and forwards them to ``predict``.

    :param args: list of command line arguments; if None, argparse reads
        them from ``sys.argv``.
    """
    parsed = _build_parser().parse_args(args)

    predict(
        gpu=parsed.gpu,
        ckpt_path=parsed.ckpt_path,
        num_workers=parsed.num_workers,
        gpu_allow_growth=parsed.gpu_allow_growth,
        split=parsed.split,
        batch_size=parsed.batch_size,
        log_dir=parsed.log_dir,
        exp_name=parsed.exp_name,
        config_path=parsed.config_path,
        save_nifti=parsed.nifti,
        save_png=parsed.png,
    )


def _str_to_bool(value: str) -> bool:
    """
    Parse a command line boolean such as "true"/"false"/"1"/"0".

    Without this, argparse would keep the raw string and "false" would be truthy.

    :param value: raw command line value.
    :return: parsed boolean.
    :raises argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def _build_parser() -> argparse.ArgumentParser:
    """
    Build the argument parser of the predict command line interface.

    :return: configured parser.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for prediction. "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=False,
    )

    # nargs="?" + const=True: a bare "-gr" enables it, "-gr false" disables it
    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of checkpointed model to load",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--split",
        help="Define the split of data to be used for prediction: "
        "train or valid or test",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--batch_size", "-b", help="Batch size for predictions", default=1, type=int
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name", "-n", help="Name of the experiment.", default="", type=str
    )

    parser.add_argument("--save_nifti", dest="nifti", action="store_true")
    parser.add_argument("--no_nifti", dest="nifti", action="store_false")
    parser.set_defaults(nifti=True)

    parser.add_argument("--save_png", dest="png", action="store_true")
    parser.add_argument("--no_png", dest="png", action="store_false")
    parser.set_defaults(png=False)

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="*",
        default="",
    )

    return parser
410
411
412
if __name__ == "__main__":
    # run the command line interface when the module is executed as a script
    main()  # pragma: no cover
414