Passed
Pull Request — main (#724)
by Yunguan
01:24
created

deepreg.predict   A

Complexity

Total Complexity 23

Size/Duplication

Total Lines 407
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 23
eloc 233
dl 0
loc 407
rs 10
c 0
b 0
f 0

5 Functions

Rating   Name   Duplication   Size   Complexity  
A build_config() 0 34 2
A build_pair_output_path() 0 23 2
C predict_on_dataset() 0 92 11
C predict() 0 117 7
B main() 0 95 1
1
# coding=utf-8
2
3
"""
Module to perform predictions on data using
the command line interface.
"""
7
8
import argparse
9
import os
10
import shutil
11
from typing import Dict, List, Tuple, Union
12
13
import numpy as np
14
import tensorflow as tf
15
16
import deepreg.config.parser as config_parser
17
import deepreg.model.layer_util as layer_util
18
import deepreg.model.optimizer as opt
19
from deepreg import log
20
from deepreg.callback import build_checkpoint_callback
21
from deepreg.registry import REGISTRY
22
from deepreg.util import (
23
    build_dataset,
24
    build_log_dir,
25
    calculate_metrics,
26
    save_array,
27
    save_metric_dict,
28
)
29
30
# Module-level logger, requested from the project's log helper under this module's name.
logger = log.get(__name__)
31
32
33
def build_pair_output_path(indices: list, save_dir: str) -> Tuple[str, str]:
    """
    Create the output directories of one image pair.

    The pair directory is named ``pair_<i>_<j>_...`` from all indices except
    the last one. If the last index (the label index) is non-negative, a
    ``label_<k>`` sub-directory is created inside the pair directory as well.

    :param indices: indices of the pair, the last one is for label
    :param save_dir: directory of output
    :return: - pair_dir, str, directory for saving the moving/fixed image
             - label_dir, str, directory for saving the rest outputs,
               identical to pair_dir when the label index is negative
    """
    # every index but the last identifies the pair
    pair_name = "pair_" + "_".join(map(str, indices[:-1]))
    pair_dir = os.path.join(save_dir, pair_name)
    os.makedirs(pair_dir, exist_ok=True)

    # a negative label index means there is no label, so fall back to pair_dir
    label_id = indices[-1]
    label_dir = pair_dir
    if label_id >= 0:
        label_dir = os.path.join(pair_dir, f"label_{label_id}")
        os.makedirs(label_dir, exist_ok=True)

    return pair_dir, label_dir
56
57
58
def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Function to predict results from a dataset from some model

    For every batch the model outputs are post-processed, each sample is
    saved under its pair/label directory and its metrics are collected.
    All metrics are saved under save_dir at the end.

    :param dataset: where data is stored
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction
    :param model_method: ddf / dvf / affine / conditional,
        currently not used inside this function
    :param save_dir: path to store dir, removed first if it already exists
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :raises ValueError: if the same sample indices are encountered twice
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover

    # a set keeps the duplicate check O(1) per sample
    seen_sample_keys = set()
    metric_lists = []
    for inputs in dataset:
        # all inputs share the batch dimension, read it from the first one
        batch_size = next(iter(inputs.values())).shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays
        # each value of processed is a tuple (array, normalize, on_label)
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        # save images of inputs and outputs
        for sample_index in range(batch_size):
            # save label independent tensors under pair_dir, otherwise under label_dir

            # init output path
            indices_i = indices[sample_index, :].astype(int).tolist()
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )

            for name, (arr, normalize, on_label) in processed.items():
                if name == "theta":
                    # the affine parameters are saved as text, not as an image
                    np.savetxt(
                        fname=os.path.join(pair_dir, "affine.txt"),
                        X=arr[sample_index, :, :],
                        delimiter=",",
                    )
                    continue

                arr_save_dir = label_dir if on_label else pair_dir
                save_array(
                    save_dir=arr_save_dir,
                    arr=arr[sample_index, :, :, :],
                    name=name,
                    normalize=normalize,  # label's value is already in [0, 1]
                    save_nifti=save_nifti,
                    save_png=save_png,
                    # True for anything under label_dir, which includes
                    # pair_dir itself when the sample has no label
                    overwrite=arr_save_dir == label_dir,
                )

            # calculate metric
            sample_key = "_".join([str(x) for x in indices_i])
            if sample_key in seen_sample_keys:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            seen_sample_keys.add(sample_key)

            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0] if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled
                else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
150
151
152
def build_config(
    config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str
) -> Tuple[Dict, str, str]:
    """
    Function to create the log directory and load the configuration.

    :param config_path: path of configuration files,
        an empty string means using the config saved next to the checkpoint.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where model is stored.
    :return: - config, configuration dictionary.
             - log_dir, path of the directory for saving outputs.
             - ckpt_path, checkpoint path with ~ replaced by the user home.
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if config_path == "":
        # use default config, which should be provided in the log folder,
        # i.e. two levels above the checkpoint (log_folder/save/ckpt-x).
        # os.path keeps short relative paths relative instead of producing
        # "/config.yaml", and honours the platform path separator.
        default_config_path = os.path.join(
            os.path.dirname(os.path.dirname(ckpt_path)), "config.yaml"
        )
        config = config_parser.load_configs(default_config_path)
    else:
        # use customized config
        logger.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path
186
187
188
def predict(
    gpu: str,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to predict some metrics from the saved model and logging results.

    :param gpu: which env gpu to use.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param mode: train / valid / test, to define which split of dataset to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpus to be used, <= 0 means not limited.
    :param gpu_allow_growth: if true, TensorFlow grows the GPU memory on demand
        instead of reserving all available GPU memory.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    """

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # "true" enables memory growth, so the flag maps directly to the env value
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    if num_workers > 0:
        # Maximum number of threads to use for OpenMP parallel regions.
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        # Without setting below 2 environment variables, it didn't work for me.
        # Thanks to @cjw85
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    try:
        # use strategy to support multiple GPUs
        # the network is mirrored in each GPU so that we can use larger batch size
        # https://www.tensorflow.org/guide/distributed_training
        # only model, optimizer and metrics need to be defined inside the strategy
        num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
        if num_devices > 1:  # pragma: no cover
            strategy = tf.distribute.MirroredStrategy()
            if batch_size % num_devices != 0:
                raise ValueError(
                    f"batch size {batch_size} can not be divided evenly "
                    f"by the number of devices."
                )
        else:
            strategy = tf.distribute.get_strategy()
        with strategy.scope():
            model: tf.keras.Model = REGISTRY.build_model(
                config=dict(
                    name=config["train"]["method"],
                    moving_image_size=data_loader.moving_image_shape,
                    fixed_image_size=data_loader.fixed_image_shape,
                    index_size=data_loader.num_indices,
                    labeled=config["dataset"]["labeled"],
                    batch_size=batch_size,
                    config=config["train"],
                )
            )
            optimizer = opt.build_optimizer(
                optimizer_config=config["train"]["optimizer"]
            )
            model.compile(optimizer=optimizer)
            model.plot_model(output_dir=log_dir)

        # load weights
        if ckpt_path.endswith(".ckpt"):
            # for ckpt from tf.keras.callbacks.ModelCheckpoint
            # skip warnings because of optimizers
            # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
            model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
        else:
            # for ckpts from ckpt manager callback
            _, _ = build_checkpoint_callback(
                model=model,
                dataset=dataset,
                log_dir=log_dir,
                save_period=config["train"]["save_period"],
                ckpt_path=ckpt_path,
            )

        # predict
        fixed_grid_ref = tf.expand_dims(
            layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
            axis=0,
        )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        predict_on_dataset(
            dataset=dataset,
            fixed_grid_ref=fixed_grid_ref,
            model=model,
            model_method=config["train"]["method"],
            save_dir=os.path.join(log_dir, "test"),
            save_nifti=save_nifti,
            save_png=save_png,
        )
    finally:
        # close the opened files in data loaders, even if the prediction failed
        data_loader.close()
305
306
307
def _parse_bool(value: str) -> bool:
    """Convert a command line string such as "true" or "0" into a boolean."""
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "on", "1"):
        return True
    # the empty string used to be the only falsy value, keep it falsy
    if lowered in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def main(args=None):
    """
    Entry point for predict script.

    Parses the command line options and forwards them to predict.

    :param args: list of command line arguments,
        None lets argparse read them from sys.argv.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for prediction. "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=True,
    )

    # Without a converter every given value, including "False", is a
    # non-empty string and therefore truthy. The option accepts an explicit
    # boolean value or can be used as a bare flag.
    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory",
        type=_parse_bool,
        nargs="?",
        const=True,
        default=False,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of checkpointed model to load",
        default="",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--mode",
        "-m",
        help="Define the split of data to be used for prediction: "
        "train or valid or test",
        type=str,
        default="test",
        required=True,
    )

    parser.add_argument(
        "--batch_size", "-b", help="Batch size for predictions", default=1, type=int
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name", "-n", help="Name of the experiment.", default="", type=str
    )

    # nifti outputs are saved by default
    parser.add_argument("--save_nifti", dest="nifti", action="store_true")
    parser.add_argument("--no_nifti", dest="nifti", action="store_false")
    parser.set_defaults(nifti=True)

    # png outputs are not saved by default
    parser.add_argument("--save_png", dest="png", action="store_true")
    parser.add_argument("--no_png", dest="png", action="store_false")
    parser.set_defaults(png=False)

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="*",
        default="",
    )

    args = parser.parse_args(args)

    predict(
        gpu=args.gpu,
        ckpt_path=args.ckpt_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        mode=args.mode,
        batch_size=args.batch_size,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        config_path=args.config_path,
        save_nifti=args.nifti,
        save_png=args.png,
    )
403
404
405
# Run the command line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()  # pragma: no cover
407