deepreg.warp   A
last analyzed

Complexity

Total Complexity 9

Size/Duplication

Total Lines 108
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 9
eloc 50
dl 0
loc 108
rs 10
c 0
b 0
f 0

3 Functions

Rating   Name   Duplication   Size   Complexity  
A warp() 0 40 4
A main() 0 21 1
A shape_sanity_check() 0 17 4
1
# coding=utf-8
2
3
"""
4
Module to warp an image with a given ddf. A CLI tool is provided.
5
"""
6
7
import argparse
8
import os
9
10
import nibabel as nib
11
import numpy as np
12
import tensorflow as tf
13
14
from deepreg import log
15
from deepreg.dataset.loader.nifti_loader import load_nifti_file
16
from deepreg.model.layer import Warping
17
18
# Logger for this module, obtained from deepreg's log helper with the module's __name__.
logger = log.get(__name__)
19
20
21
def shape_sanity_check(image: np.ndarray, ddf: np.ndarray):
    """
    Verify that image and ddf have the expected number of dimensions.

    Only the ranks (and the size of the last ddf axis) are validated;
    the spatial sizes of image and ddf are not compared with each other.

    :param image: a numpy array of shape (m_dim1, m_dim2, m_dim3)
        or (m_dim1, m_dim2, m_dim3, ch)
    :param ddf: a numpy array of shape (f_dim1, f_dim2, f_dim3, 3)
    :raises ValueError: if image is neither 3D nor 4D,
        or if ddf is not 4D with a last axis of size 3
    """
    # The image is checked first, so when both inputs are invalid
    # the image error is the one reported.
    if len(image.shape) not in [3, 4]:
        raise ValueError(
            "image shape must be (m_dim1, m_dim2, m_dim3) "
            "or (m_dim1, m_dim2, m_dim3, ch),"
            f" got {image.shape}"
        )
    if not (len(ddf.shape) == 4 and ddf.shape[-1] == 3):
        raise ValueError(
            f"ddf shape must be (f_dim1, f_dim2, f_dim3, 3), got {ddf.shape}"
        )
39
40
41
def warp(image_path: str, ddf_path: str, out_path: str):
    """
    Warp an image with a ddf and save the warped image as a nifti file.

    :param image_path: file path of the image file
    :param ddf_path: file path of the ddf file
    :param out_path: file path of the output.
        If empty, the output is saved as "warped.nii.gz".
        If it does not end with .nii or .nii.gz, the output is saved as
        "warped.nii.gz" in the directory part of the given path.
    """
    if out_path == "":
        out_path = "warped.nii.gz"
        logger.warning(
            "Output file path is not provided, will save output in %s.", out_path
        )
    else:
        if not out_path.endswith((".nii", ".nii.gz")):
            out_path = os.path.join(os.path.dirname(out_path), "warped.nii.gz")
            logger.warning(
                "Output file path should end with .nii or .nii.gz, "
                "will save output in %s.",
                out_path,
            )
        # A bare file name has an empty directory part and os.makedirs("")
        # raises FileNotFoundError, so only create a folder when there is one.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # load image and ddf
    image = load_nifti_file(image_path)
    ddf = load_nifti_file(ddf_path)

    # validate ranks before the ddf shape is used to size the output
    shape_sanity_check(image=image, ddf=ddf)
    fixed_image_shape = ddf.shape[:3]

    # add batch dimension manually
    image = tf.expand_dims(image, axis=0)
    ddf = tf.expand_dims(ddf, axis=0)

    # warp, the layer is called with the inputs ordered as [ddf, image]
    warped_image = Warping(fixed_image_size=fixed_image_shape)([ddf, image])
    warped_image = warped_image.numpy()
    warped_image = warped_image[0, ...]  # removed added batch dimension

    # save output, the nifti file is written with an identity affine
    nib.save(img=nib.Nifti1Image(warped_image, affine=np.eye(4)), filename=out_path)

    logger.info("Warped image has been saved at %s.", out_path)
81
82
83
def main(args=None):
    """
    Command line entry point of the warp tool.

    Parses the --image/-i, --ddf/-d and --out/-o options
    and forwards them to warp.

    :param args: argument strings handed to ``ArgumentParser.parse_args``,
        None by default.
    """
    cli = argparse.ArgumentParser()

    # both input paths are mandatory and declared the same way
    required_paths = (
        ("--image", "-i", "File path for image file"),
        ("--ddf", "-d", "File path for ddf file"),
    )
    for long_flag, short_flag, help_text in required_paths:
        cli.add_argument(
            long_flag, short_flag, help=help_text, type=str, required=True
        )

    # the output path is optional, warp receives "" when it is omitted
    cli.add_argument("--out", "-o", help="Output path for warped image", default="")

    parsed = cli.parse_args(args)
    warp(image_path=parsed.image, ddf_path=parsed.ddf, out_path=parsed.out)
104
105
106
if __name__ == "__main__":
    # Run the command line interface only when executed as a script.
    main()  # pragma: no cover
108