Passed: Pull Request — main (#724), by Yunguan, created 01:24

deepreg.predict.predict() (rated C)

Complexity
  Conditions: 7

Size
  Total Lines: 117
  Code Lines: 67

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0
Metric  Value
eloc    67     (effective lines of code; matches Code Lines)
dl      0      (duplicated lines; matches Duplication Lines)
loc     117    (lines of code; matches Total Lines)
rs      6.68
c       0
b       0
f       0
cc      7      (cyclomatic complexity; matches Conditions)
nop     11     (number of parameters; see Many Parameters below)
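The remaining abbreviations (rs, c, b, f) are the analysis tool's own and are not expanded on this page. As a rough local cross-check (an assumption, not necessarily the tool that produced this report), the radon library computes a comparable per-function cyclomatic complexity:

# Sketch: cross-check the cc metric locally with radon (pip install radon).
# The file path is a placeholder; adjust it to where the module lives.
from radon.complexity import cc_visit

with open("deepreg/predict.py") as f:
    source = f.read()

# print each function's name and its cyclomatic complexity
for block in cc_visit(source):
    print(block.name, block.complexity)

Different tools count branches slightly differently, so the numbers may not match exactly.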

How to fix: Long Method · Many Parameters

Long Method

Small methods make your code easier to understand, especially when combined with a good name; and when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments inside a method's body, that is usually a sign to extract the commented part into a new method, using the comment as a starting point for its name.

Commonly applied refactorings include Extract Method, Replace Temp with Query, and Decompose Conditional; a sketch of Extract Method applied to this module follows below.
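As an illustration only, not the project's actual refactoring: the environment setup at the top of predict() could be extracted into a named helper. _configure_environment is a hypothetical name; the body mirrors the first block of predict() shown further below.

import os


def _configure_environment(gpu: str, gpu_allow_growth: bool, num_workers: int) -> None:
    """Hypothetical helper extracted from predict()."""
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "false" if gpu_allow_growth else "true"
    if num_workers > 0:
        # limit the OpenMP and TensorFlow thread pools to num_workers threads
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

The comment that previously explained the block becomes the method name, and predict() shrinks by one responsibility.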

Many Parameters

Methods with many parameters are not only hard to understand; their parameter lists also tend to become inconsistent as you need more, or different, data.

There are several approaches to avoiding long parameter lists, including Introduce Parameter Object and Preserve Whole Object; a sketch of the former follows below.
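As a sketch of Introduce Parameter Object (SaveOptions is a hypothetical type, not part of DeepReg's API), the output-related arguments of predict() could be grouped into one value:

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class SaveOptions:
    """Hypothetical parameter object grouping the output flags."""

    save_nifti: bool = True
    save_png: bool = True
    log_dir: str = "logs"


def predict(
    gpu: str,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    save_options: Optional[SaveOptions] = None,
):
    # fall back to the defaults when no options are given
    save_options = save_options or SaveOptions()
    ...

Callers then pass a single SaveOptions value instead of three loose flags, and new output settings no longer touch every call site. The full source of the function under review follows.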

# coding=utf-8

"""
Module to perform predictions on data using
command line interface.
"""

import argparse
import os
import shutil
from typing import Dict, List, Tuple, Union

import numpy as np
import tensorflow as tf

import deepreg.config.parser as config_parser
import deepreg.model.layer_util as layer_util
import deepreg.model.optimizer as opt
from deepreg import log
from deepreg.callback import build_checkpoint_callback
from deepreg.registry import REGISTRY
from deepreg.util import (
    build_dataset,
    build_log_dir,
    calculate_metrics,
    save_array,
    save_metric_dict,
)

logger = log.get(__name__)


def build_pair_output_path(indices: list, save_dir: str) -> Tuple[str, str]:
    """
    Create the directory for saving the paired data.

    :param indices: indices of the pair, the last one is for the label
    :param save_dir: directory of output
    :return: - pair_dir, str, directory for saving the moving/fixed image
             - label_dir, str, directory for saving the remaining outputs
    """

    # cast indices to string and init directory name
    pair_index = "pair_" + "_".join([str(x) for x in indices[:-1]])
    pair_dir = os.path.join(save_dir, pair_index)
    os.makedirs(pair_dir, exist_ok=True)

    if indices[-1] >= 0:
        label_index = f"label_{indices[-1]}"
        label_dir = os.path.join(pair_dir, label_index)
        os.makedirs(label_dir, exist_ok=True)
    else:
        label_dir = pair_dir

    return pair_dir, label_dir
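
# Example (hypothetical values): build_pair_output_path(indices=[0, 2, 1], save_dir="out")
# creates out/pair_0_2/label_1/ and returns ("out/pair_0_2", "out/pair_0_2/label_1").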


def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Function to predict results on a dataset using a given model.

    :param dataset: where data is stored
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: ddf / dvf / affine / conditional
    :param save_dir: path of the directory for saving outputs
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    sample_index_strs = []
    metric_lists = []
    for inputs in dataset:
        batch_size = inputs[list(inputs.keys())[0]].shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays
        # processed maps a tensor name to a tuple (array, normalize, on_label)
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # save label-independent tensors under pair_dir, otherwise under label_dir

            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label's value is already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    overwrite=arr_save_dir == label_dir,
                )

            # calculate metrics
            sample_index_str = "_".join([str(x) for x in indices_i])
            if sample_index_str in sample_index_strs:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            sample_index_strs.append(sample_index_str)

            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0] if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled
                else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metrics
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)


def build_config(
    config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str
) -> Tuple[Dict, str, str]:
    """
    Initialize the log directory and load the configuration.

    :param config_path: path of the configuration files.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where the model is stored.
    :return: - config, configuration dictionary.
             - log_dir, path of the directory for saving outputs.
             - ckpt_path, path where the model is stored, with ~ expanded.
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with the user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if config_path == "":
        # use the default config, which should be provided in the log folder,
        # e.g. for ckpt_path log_folder/save/ckpt-x this resolves to
        # log_folder/config.yaml
        config = config_parser.load_configs(
            "/".join(ckpt_path.split("/")[:-2]) + "/config.yaml"
        )
    else:
        # use customized config
        logger.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path


def predict(
    gpu: str,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to run prediction with a saved model and log the results.

    :param gpu: which GPU(s) to use, passed to CUDA_VISIBLE_DEVICES.
    :param ckpt_path: where the model is stored, should be like log_folder/save/ckpt-x.
    :param mode: train / valid / test, the split of the dataset to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpus to be used, <= 0 means unlimited.
    :param gpu_allow_growth: whether to allocate the whole GPU memory.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    """

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "false" if gpu_allow_growth else "true"
    if num_workers > 0:
        # maximum number of threads to use for OpenMP parallel regions
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        # setting the two TF thread-count variables below is also required,
        # OMP_NUM_THREADS alone is not sufficient. Thanks to @cjw85
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use a larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} cannot be divided evenly "
                f"by the number of devices."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader.moving_image_shape,
                fixed_image_size=data_loader.fixed_image_shape,
                index_size=data_loader.num_indices,
                labeled=config["dataset"]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from the ckpt manager callback
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape), axis=0
    )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        model_method=config["train"]["method"],
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in the data loaders
    data_loader.close()


def main(args=None):
    """
    Entry point for the predict script.

    :param args: command-line arguments; defaults to sys.argv when None.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for prediction. "
        '-g "" for using CPU; '
        '-g "0" for using GPU 0; '
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=True,
    )

    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory",
        action="store_true",
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of the checkpointed model to load",
        default="",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--mode",
        "-m",
        help="Define the split of data to be used for prediction: "
        "train, valid, or test.",
        type=str,
        default="test",
        required=True,
    )

    parser.add_argument(
        "--batch_size", "-b", help="Batch size for predictions", default=1, type=int
    )

    parser.add_argument(
        "--log_dir", help="Path of the log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name", "-n", help="Name of the experiment.", default="", type=str
    )

    parser.add_argument("--save_nifti", dest="nifti", action="store_true")
    parser.add_argument("--no_nifti", dest="nifti", action="store_false")
    parser.set_defaults(nifti=True)

    parser.add_argument("--save_png", dest="png", action="store_true")
    parser.add_argument("--no_png", dest="png", action="store_false")
    parser.set_defaults(png=False)

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of the config file, must end with .yaml. Multiple paths can be passed.",
        type=str,
        nargs="*",
        default="",
    )

    args = parser.parse_args(args)

    predict(
        gpu=args.gpu,
        ckpt_path=args.ckpt_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        mode=args.mode,
        batch_size=args.batch_size,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        config_path=args.config_path,
        save_nifti=args.nifti,
        save_png=args.png,
    )


if __name__ == "__main__":
    main()  # pragma: no cover
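
For reference, a minimal sketch of calling the function programmatically instead of via the CLI; the checkpoint path and experiment name below are placeholder values, and config_path="" falls back to the config.yaml stored in the checkpoint's log folder:

from deepreg.predict import predict

predict(
    gpu="",  # run on CPU
    ckpt_path="logs/demo/save/ckpt-4000",  # placeholder checkpoint path
    mode="test",
    batch_size=1,
    exp_name="demo_predict",  # placeholder experiment name
    config_path="",  # reuse the config saved during training
)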