Passed
Pull Request — main (#724)
by Yunguan
01:33
created

deepreg.predict   A

Complexity

Total Complexity 23

Size/Duplication

Total Lines 405
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 23
eloc 232
dl 0
loc 405
rs 10
c 0
b 0
f 0

5 Functions

Rating   Name   Duplication   Size   Complexity  
A build_config() 0 34 2
A build_pair_output_path() 0 23 2
C predict_on_dataset() 0 92 11
C predict() 0 117 7
B main() 0 95 1
1
# coding=utf-8
2
3
"""
4
Module to perform predictions on data using
5
command line interface.
6
"""
7
8
import argparse
9
import logging
10
import os
11
import shutil
12
from typing import Dict, List, Tuple, Union
13
14
import numpy as np
15
import tensorflow as tf
16
17
import deepreg.config.parser as config_parser
18
import deepreg.model.layer_util as layer_util
19
import deepreg.model.optimizer as opt
20
from deepreg.callback import build_checkpoint_callback
21
from deepreg.registry import REGISTRY
22
from deepreg.util import (
23
    build_dataset,
24
    build_log_dir,
25
    calculate_metrics,
26
    save_array,
27
    save_metric_dict,
28
)
29
30
31
def build_pair_output_path(indices: list, save_dir: str) -> Tuple[str, str]:
    """
    Create (if needed) the output directories of one sample and return them.

    The pair directory is named after every index except the last one, e.g.
    indices [1, 2, 3] give ``save_dir/pair_1_2``. The last index is the label
    index; when it is non-negative a sub-directory ``label_<index>`` is created
    inside the pair directory.

    :param indices: indices of the sample, the last one is for label
    :param save_dir: root directory of the outputs
    :return: - pair_dir, str, directory for saving the label independent outputs
               (e.g. moving/fixed image)
             - label_dir, str, directory for saving the label dependent outputs;
               equal to pair_dir when the label index is negative
    """
    pair_name = "pair_" + "_".join(map(str, indices[:-1]))
    pair_dir = os.path.join(save_dir, pair_name)
    os.makedirs(pair_dir, exist_ok=True)

    label_index = indices[-1]
    if label_index < 0:
        # negative label index: label dependent outputs share the pair directory
        return pair_dir, pair_dir

    label_dir = os.path.join(pair_dir, f"label_{label_index}")
    os.makedirs(label_dir, exist_ok=True)
    return pair_dir, label_dir
54
55
56
def _save_sample_outputs(
    processed: Dict[str, tuple],
    sample_index: int,
    pair_dir: str,
    label_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Save the post-processed outputs of one sample of a batch.

    ``theta`` is saved as ``affine.txt`` under pair_dir; every other array is
    saved with save_array, under label_dir if it is label dependent, otherwise
    under pair_dir.

    :param processed: name -> (batched array, normalize flag, on_label flag)
    :param sample_index: index of the sample inside the batch
    :param pair_dir: directory for label independent outputs
    :param label_dir: directory for label dependent outputs
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    """
    for name, (arr, normalize, on_label) in processed.items():
        if name == "theta":
            np.savetxt(
                fname=os.path.join(pair_dir, "affine.txt"),
                X=arr[sample_index, :, :],
                delimiter=",",
            )
            continue

        arr_save_dir = label_dir if on_label else pair_dir
        save_array(
            save_dir=arr_save_dir,
            arr=arr[sample_index, :, :, :],
            name=name,
            normalize=normalize,  # label's value is already in [0, 1]
            save_nifti=save_nifti,
            save_png=save_png,
            # pair_dir is shared by all labels of a pair; presumably save_array
            # keeps existing files when overwrite is False, so label independent
            # outputs are not rewritten for every label
            overwrite=arr_save_dir == label_dir,
        )


def predict_on_dataset(
    dataset: tf.data.Dataset,
    fixed_grid_ref: tf.Tensor,
    model: tf.keras.Model,
    model_method: str,
    save_dir: str,
    save_nifti: bool,
    save_png: bool,
):
    """
    Function to predict results from a dataset from some model.

    Every sample's outputs are saved under ``save_dir`` and the per-sample
    metrics are written with save_metric_dict. An existing ``save_dir`` is
    removed first so that results of a previous run are not mixed in.

    :param dataset: where data is stored, yields dicts of batched inputs;
        it must not be repeated.
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param model: model to be used for prediction, it must provide
        ``postprocess`` and ``labeled``.
    :param model_method: ddf / dvf / affine / conditional; currently unused,
        kept for backward compatibility.
    :param save_dir: path to store dir
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :raises ValueError: if the same sample indices are produced twice.
    """
    # remove the save_dir in case it exists
    if os.path.exists(save_dir):
        shutil.rmtree(save_dir)  # pragma: no cover
    # make sure the directory exists for save_metric_dict even if no batch comes
    os.makedirs(save_dir, exist_ok=True)

    seen_sample_indices = set()  # set for O(1) repeated-sample detection
    metric_lists = []
    for inputs in dataset:
        # batch size of the current batch, read from the first input tensor
        batch_size = next(iter(inputs.values())).shape[0]
        outputs = model.predict(x=inputs, batch_size=batch_size)
        indices, processed = model.postprocess(inputs=inputs, outputs=outputs)

        # convert to np arrays
        indices = indices.numpy()
        processed = {
            k: (v[0].numpy() if isinstance(v[0], tf.Tensor) else v[0], v[1], v[2])
            for k, v in processed.items()
        }

        for sample_index in range(batch_size):
            indices_i = indices[sample_index, :].astype(int).tolist()

            # detect repeated samples before anything is written to disk
            sample_index_str = "_".join(str(x) for x in indices_i)
            if sample_index_str in seen_sample_indices:  # pragma: no cover
                raise ValueError(
                    "Sample is repeated, maybe the dataset has been repeated."
                )
            seen_sample_indices.add(sample_index_str)

            # save label independent tensors under pair_dir, otherwise under label_dir
            pair_dir, label_dir = build_pair_output_path(
                indices=indices_i, save_dir=save_dir
            )
            _save_sample_outputs(
                processed=processed,
                sample_index=sample_index,
                pair_dir=pair_dir,
                label_dir=label_dir,
                save_nifti=save_nifti,
                save_png=save_png,
            )

            # calculate metric
            metric = calculate_metrics(
                fixed_image=processed["fixed_image"][0],
                fixed_label=processed["fixed_label"][0] if model.labeled else None,
                pred_fixed_image=processed["pred_fixed_image"][0],
                pred_fixed_label=processed["pred_fixed_label"][0]
                if model.labeled
                else None,
                fixed_grid_ref=fixed_grid_ref,
                sample_index=sample_index,
            )
            metric["pair_index"] = indices_i[:-1]
            metric["label_index"] = indices_i[-1]
            metric_lists.append(metric)

    # save metric
    save_metric_dict(save_dir=save_dir, metrics=metric_lists)
148
149
150
def build_config(
    config_path: Union[str, List[str]], log_dir: str, exp_name: str, ckpt_path: str
) -> Tuple[Dict, str, str]:
    """
    Load the configuration and create the log directory to store results.

    :param config_path: path(s) of configuration files. An empty value
        ("" or an empty list) means using ``config.yaml`` from the log folder
        of the checkpoint, i.e. two levels above ``ckpt_path``.
    :param log_dir: path of the log directory.
    :param exp_name: experiment name.
    :param ckpt_path: path where model is stored, e.g. log_folder/save/ckpt-x.
    :return: - config, configuration dictionary.
             - log_dir, path of the directory for saving outputs.
             - ckpt_path, checkpoint path with ~ expanded to the user home.
    """

    # init log directory
    log_dir = build_log_dir(log_dir=log_dir, exp_name=exp_name)

    # replace the ~ with user home path
    ckpt_path = os.path.expanduser(ckpt_path)

    # load config
    if not config_path:
        # covers both "" (CLI default) and [] ("-c" given without any path):
        # use default config, which should be provided in the log folder,
        # log_folder/save/ckpt-x -> log_folder/config.yaml
        ckpt_log_folder = os.path.dirname(os.path.dirname(ckpt_path))
        config = config_parser.load_configs(
            os.path.join(ckpt_log_folder, "config.yaml")
        )
    else:
        # use customized config
        logging.warning(
            "Using customized configuration. "
            "The code might break if the config doesn't match the saved model."
        )
        config = config_parser.load_configs(config_path)
    return config, log_dir, ckpt_path
184
185
186
def predict(
    gpu: str,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to predict some metrics from the saved model and logging results.

    :param gpu: which env gpu to use, assigned to CUDA_VISIBLE_DEVICES.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param mode: train / valid / test, to define which split of dataset to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpus to be used, <= 0 means not limited.
    :param gpu_allow_growth: if true, let TensorFlow allocate GPU memory on
        demand instead of reserving all of it upfront.
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    :raises ValueError: if batch_size is not divisible by the number of GPUs.
    """

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    # "true" makes TensorFlow grow GPU memory on demand (was inverted before)
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true" if gpu_allow_growth else "false"
    if num_workers > 0:
        # Maximum number of threads to use for OpenMP parallel regions.
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        # Without setting below 2 environment variables, it didn't work for me. Thanks to @cjw85
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path, log_dir=log_dir, exp_name=exp_name, ckpt_path=ckpt_path
    )
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # close the opened files in data loaders even if prediction fails
    try:
        # use strategy to support multiple GPUs
        # the network is mirrored in each GPU so that we can use larger batch size
        # https://www.tensorflow.org/guide/distributed_training
        # only model, optimizer and metrics need to be defined inside the strategy
        num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
        if num_devices > 1:  # pragma: no cover
            strategy = tf.distribute.MirroredStrategy()
            if batch_size % num_devices != 0:
                raise ValueError(
                    f"batch size {batch_size} can not be divided evenly "
                    f"by the number of devices."
                )
        else:
            strategy = tf.distribute.get_strategy()
        with strategy.scope():
            model: tf.keras.Model = REGISTRY.build_model(
                config=dict(
                    name=config["train"]["method"],
                    moving_image_size=data_loader.moving_image_shape,
                    fixed_image_size=data_loader.fixed_image_shape,
                    index_size=data_loader.num_indices,
                    labeled=config["dataset"]["labeled"],
                    batch_size=batch_size,
                    config=config["train"],
                )
            )
            optimizer = opt.build_optimizer(
                optimizer_config=config["train"]["optimizer"]
            )
            model.compile(optimizer=optimizer)
            model.plot_model(output_dir=log_dir)

        # load weights
        if ckpt_path.endswith(".ckpt"):
            # for ckpt from tf.keras.callbacks.ModelCheckpoint
            # skip warnings because of optimizers
            # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
            model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
        else:
            # for ckpts from ckpt manager callback
            _, _ = build_checkpoint_callback(
                model=model,
                dataset=dataset,
                log_dir=log_dir,
                save_period=config["train"]["save_period"],
                ckpt_path=ckpt_path,
            )

        # predict
        fixed_grid_ref = tf.expand_dims(
            layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
            axis=0,
        )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
        predict_on_dataset(
            dataset=dataset,
            fixed_grid_ref=fixed_grid_ref,
            model=model,
            model_method=config["train"]["method"],
            save_dir=os.path.join(log_dir, "test"),
            save_nifti=save_nifti,
            save_png=save_png,
        )
    finally:
        data_loader.close()
303
304
305
def _str_to_bool(value: Union[str, bool]) -> bool:
    """
    Parse a boolean command line value.

    :param value: text such as "true"/"false", "yes"/"no", "1"/"0"
        (case-insensitive), or a bool which is returned unchanged.
    :return: the parsed boolean.
    :raises argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def main(args=None):
    """
    Entry point for predict script.

    :param args: list of command line arguments to parse; when None, argparse
        reads them from sys.argv.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--gpu",
        "-g",
        help="GPU index for prediction. "
        '-g "" for using CPU, '
        '-g "0" for using GPU 0, '
        '-g "0,1" for using GPU 0 and 1.',
        type=str,
        required=True,
    )

    # "-gr" alone enables it; "-gr true/false" is parsed as a boolean instead of
    # being kept as a (always truthy) string
    parser.add_argument(
        "--gpu_allow_growth",
        "-gr",
        help="Prevent TensorFlow from reserving all available GPU memory",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
    )

    parser.add_argument(
        "--num_workers",
        help="Number of CPUs to be used, <= 0 means unlimited.",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--ckpt_path",
        "-k",
        help="Path of checkpointed model to load",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--mode",
        "-m",
        help="Define the split of data to be used for prediction: "
        "train or valid or test",
        type=str,
        required=True,
    )

    parser.add_argument(
        "--batch_size", "-b", help="Batch size for predictions", default=1, type=int
    )

    parser.add_argument(
        "--log_dir", help="Path of log directory.", default="logs", type=str
    )

    parser.add_argument(
        "--exp_name", "-n", help="Name of the experiment.", default="", type=str
    )

    parser.add_argument("--save_nifti", dest="nifti", action="store_true")
    parser.add_argument("--no_nifti", dest="nifti", action="store_false")
    parser.set_defaults(nifti=True)

    parser.add_argument("--save_png", dest="png", action="store_true")
    parser.add_argument("--no_png", dest="png", action="store_false")
    parser.set_defaults(png=False)

    parser.add_argument(
        "--config_path",
        "-c",
        help="Path of config, must end with .yaml. Can pass multiple paths.",
        type=str,
        nargs="*",
        default="",
    )

    args = parser.parse_args(args)

    predict(
        gpu=args.gpu,
        ckpt_path=args.ckpt_path,
        num_workers=args.num_workers,
        gpu_allow_growth=args.gpu_allow_growth,
        mode=args.mode,
        batch_size=args.batch_size,
        log_dir=args.log_dir,
        exp_name=args.exp_name,
        config_path=args.config_path,
        save_nifti=args.nifti,
        save_png=args.png,
    )


if __name__ == "__main__":
    main()  # pragma: no cover
405