1
|
|
|
import os |
2
|
|
|
|
3
|
|
|
import numpy as np |
4
|
|
|
import pytest |
5
|
|
|
|
6
|
|
|
from deepreg.warp import main, shape_sanity_check |
7
|
|
|
|
8
|
|
|
# Input files fed to the warp CLI by test_main below. Both paths are relative
# ("./..."), so they resolve against the current working directory.
image_path: str = "./data/test/nifti/unit_test/moving_image.nii.gz"
ddf_path: str = "./data/test/nifti/unit_test/ddf.nii.gz"
10
|
|
|
|
11
|
|
|
|
12
|
|
|
@pytest.mark.parametrize(
    ("out_path", "expected_path"),
    [
        # --out names a .nii.gz file: the output is expected at that exact path.
        ("logs/test_warp/out.nii.gz", "logs/test_warp/out.nii.gz"),
        # --out names a non-NIfTI file: the output is expected as
        # warped.nii.gz in the same directory.
        ("logs/test_warp/out.h5", "logs/test_warp/warped.nii.gz"),
        # --out names a directory: the output is expected as warped.nii.gz
        # inside it.
        ("logs/test_warp/", "logs/test_warp/warped.nii.gz"),
        # --out is empty: the output is expected as warped.nii.gz in the
        # current working directory.
        ("", "warped.nii.gz"),
    ],
)
def test_main(out_path: str, expected_path: str):
    """
    Run the warp entry point and check where the output file lands.

    :param out_path: value passed to the ``--out`` command line option.
    :param expected_path: file that must exist once ``main`` has returned.
    """
    cli_args = ["--image", image_path, "--ddf", ddf_path, "--out", out_path]
    main(args=cli_args)

    assert os.path.isfile(expected_path)

    # Delete the produced file so the next parametrized case starts clean.
    os.remove(expected_path)
25
|
|
|
|
26
|
|
|
|
27
|
|
|
class TestShapeSanityCheck:
    """Exercise shape_sanity_check with accepted and rejected array shapes."""

    @pytest.mark.parametrize(
        ("image_shape", "ddf_shape"),
        [
            ((2, 3, 4), (2, 3, 4, 3)),
            ((2, 3, 4, 1), (2, 3, 4, 3)),
            ((2, 3, 4, 3), (2, 3, 4, 3)),
        ],
    )
    def test_pass(self, image_shape: tuple, ddf_shape: tuple):
        """
        The check must return without raising for these shape pairs.

        :param image_shape: shape of the all-ones image array.
        :param ddf_shape: shape of the all-ones ddf array.
        """
        shape_sanity_check(image=np.ones(image_shape), ddf=np.ones(ddf_shape))

    @pytest.mark.parametrize(
        ("image_shape", "ddf_shape", "err_msg"),
        [
            (
                (2, 3),
                (2, 3, 4, 3),
                "image shape must be (m_dim1, m_dim2, m_dim3)",
            ),
            (
                (2, 3, 4),
                (2, 3, 4, 4),
                "ddf shape must be (f_dim1, f_dim2, f_dim3, 3)",
            ),
        ],
    )
    def test_error(self, image_shape: tuple, ddf_shape: tuple, err_msg):
        """
        The check must raise ValueError carrying the expected message.

        :param image_shape: shape of the all-ones image array.
        :param ddf_shape: shape of the all-ones ddf array.
        :param err_msg: text that must appear in the raised error's message.
        """
        # Build the inputs outside the raises-context so that only the call
        # under test is allowed to produce the ValueError.
        image_arr = np.ones(image_shape)
        ddf_arr = np.ones(ddf_shape)

        with pytest.raises(ValueError) as err_info:
            shape_sanity_check(image=image_arr, ddf=ddf_arr)

        raised_text = str(err_info.value)
        assert err_msg in raised_text
61
|
|
|
|