test_demos.remove_files()   A
last analyzed

Complexity

Conditions 5

Size

Total Lines 17
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 12
dl 0
loc 17
rs 9.3333
c 0
b 0
f 0
cc 5
nop 1
1
import os
2
import shutil
3
import subprocess
4
5
import pytest
6
7
from deepreg import log
8
9
logger = log.get(__name__)
10
11
12
def remove_files(name):
    """Delete zip archives and generated output folders of a demo.

    Works on ``demos/<name>``: every directory entry whose name ends with
    ``.zip`` is removed, then the ``dataset``, ``logs_train``,
    ``logs_predict`` and ``logs_reg`` sub-folders are removed when present.

    :param name: name of the demo folder under ``demos``
    """
    demo_dir = os.path.join("demos", name)

    # zip archives located directly inside the demo folder
    zip_names = [entry for entry in os.listdir(demo_dir) if entry.endswith(".zip")]
    for zip_name in zip_names:
        os.remove(os.path.join(demo_dir, zip_name))

    # output folders, skipped when they do not exist
    for folder in ["dataset", "logs_train", "logs_predict", "logs_reg"]:
        target = os.path.join(demo_dir, folder)
        if not os.path.exists(target):
            continue
        shutil.rmtree(target)
29
30
31
def check_files(name):
    """Assert that the dataset folder exists and that no zip file is left.

    :param name: name of the demo folder under ``demos``
    :raises AssertionError: if ``demos/<name>/dataset`` is missing or if any
        entry of ``demos/<name>`` ends with ``.zip``
    """
    demo_dir = os.path.join("demos", name)
    dataset_dir = os.path.join(demo_dir, "dataset")

    # the dataset folder has to be present
    assert os.path.exists(dataset_dir)

    # no zip archive may remain next to it
    leftover_zips = [entry for entry in os.listdir(demo_dir) if entry.endswith(".zip")]
    assert not leftover_zips
42
43
44
def check_vis_single_config_demo(name):
    """Run ``deepreg_vis`` on the prediction output of a single-config demo.

    The first entry (in sorted order) of ``demos/<name>/logs_predict`` is used
    as the run folder and the last entry (in sorted order) of its ``test``
    sub-folder as the sample to visualise.

    :param name: name of the demo folder under ``demos``
    :raises AssertionError: if ``visualisation.png`` is not found in
        ``demos/<name>/logs_predict`` after the command has run
    """
    log_dir = f"demos/{name}/logs_predict"
    time_stamp = sorted(os.listdir(log_dir))[0]
    pair_number = sorted(os.listdir(f"{log_dir}/{time_stamp}/test"))[-1]
    pair_dir = f"{log_dir}/{time_stamp}/test/{pair_number}"

    images = ", ".join(
        f"{pair_dir}/{x}.nii.gz"
        for x in ["moving_image", "pred_fixed_image", "fixed_image"]
    )
    cmd = f"deepreg_vis -m 2 -i '{images}' --slice-inds '0,1,2' -s {log_dir}"

    # execute_commands expects a flat list of command strings; wrapping the
    # command in an extra list would hand subprocess a sequence with shell=True
    execute_commands([cmd])
    assert os.path.exists(f"{log_dir}/visualisation.png")
52
53
54
def check_vis_unpaired_ct_abdomen(name, method):
    """Run ``deepreg_vis`` on the prediction output of one training method.

    The first entry (in sorted order) of ``demos/<name>/logs_predict/<method>``
    is used as the run folder and the last entry (in sorted order) of its
    ``test`` sub-folder as the sample to visualise. The figure is saved to
    ``demos/<name>/logs_predict``, not to the method sub-folder.

    :param name: name of the demo folder under ``demos``
    :param method: sub-folder of ``logs_predict`` holding the predictions
    :raises AssertionError: if ``visualisation.png`` is not found in
        ``demos/<name>/logs_predict`` after the command has run
    """
    log_dir = f"demos/{name}/logs_predict"
    method_dir = f"{log_dir}/{method}"
    time_stamp = sorted(os.listdir(method_dir))[0]
    pair_number = sorted(os.listdir(f"{method_dir}/{time_stamp}/test"))[-1]
    pair_dir = f"{method_dir}/{time_stamp}/test/{pair_number}"

    images = ", ".join(
        f"{pair_dir}/{x}.nii.gz"
        for x in ["moving_image", "pred_fixed_image", "fixed_image"]
    )
    cmd = f"deepreg_vis -m 2 -i '{images}' --slice-inds '0,1,2' -s {log_dir}"

    # execute_commands expects a flat list of command strings; wrapping the
    # command in an extra list would hand subprocess a sequence with shell=True
    execute_commands([cmd])
    assert os.path.exists(f"{log_dir}/visualisation.png")
64
65
66
def check_vis_classical_demo(name):
    """Run ``deepreg_vis`` on the registration output of a classical demo.

    Visualises the moving, warped moving and fixed images stored in
    ``demos/<name>/logs_reg`` and saves the figure to the same folder.

    :param name: name of the demo folder under ``demos``
    :raises AssertionError: if ``visualisation.png`` is not found in
        ``demos/<name>/logs_reg`` after the command has run
    """
    log_dir = f"demos/{name}/logs_reg"

    images = ", ".join(
        f"{log_dir}/{x}.nii.gz"
        for x in ["moving_image", "warped_moving_image", "fixed_image"]
    )
    cmd = f"deepreg_vis -m 2 -i '{images}' --slice-inds '0,1,2' -s {log_dir}"

    # execute_commands expects a flat list of command strings; wrapping the
    # command in an extra list would hand subprocess a sequence with shell=True
    execute_commands([cmd])
    assert os.path.exists(f"{log_dir}/visualisation.png")
72
73
74
def execute_commands(cmds):
    """Run the commands one after another through the shell and log the output.

    Each command is logged, executed with ``subprocess.check_output`` using
    ``shell=True``, and its standard output is decoded as UTF-8 and logged.
    Execution stops at the first failing command.

    :param cmds: iterable of commands to execute in order
    :raises RuntimeError: if a command exits with a non-zero return code; the
        underlying ``subprocess.CalledProcessError`` is attached as the cause
    """
    for cmd in cmds:
        try:
            logger.info(f"Running {cmd}")
            out = subprocess.check_output(cmd, shell=True).decode("utf-8")
            logger.info(out)
        except subprocess.CalledProcessError as e:
            # chain explicitly so the traceback marks the subprocess failure
            # as the direct cause of the RuntimeError
            raise RuntimeError(
                f"Command {cmd} return with err {e.returncode} {e.output}"
            ) from e
84
85
86
class TestDemo:
    """Tests running the demo scripts as shell commands and checking outputs."""

    @staticmethod
    def _prepare_data(name):
        """Clean the demo folder, run its data script and verify the result."""
        remove_files(name)

        # execute data
        execute_commands([f"python demos/{name}/demo_data.py"])

        # check temporary files are removed
        check_files(name)

    @pytest.mark.parametrize(
        "name",
        [
            "grouped_mask_prostate_longitudinal",
            "grouped_mr_heart",
            "paired_ct_lung",
            "paired_mrus_brain",
            "paired_mrus_prostate",
            "unpaired_ct_lung",
            "unpaired_mr_brain",
            "unpaired_us_prostate_cv",
        ],
    )
    def test_single_config_demo(self, name):
        """Demos that come with one single configuration file."""
        self._prepare_data(name)

        # train first, then predict
        stage_cmds = [
            f"python demos/{name}/demo_{stage}.py --test"
            for stage in ["train", "predict"]
        ]
        execute_commands(stage_cmds)

        check_vis_single_config_demo(name)

    @pytest.mark.parametrize(
        "method",
        ["comb", "unsup", "weakly"],
    )
    def test_unpaired_ct_abdomen(self, method):
        """Demo that comes with multiple configuration files, one per method."""
        name = "unpaired_ct_abdomen"
        self._prepare_data(name)

        # train first, then predict, both with the selected method
        stage_cmds = [
            f"python demos/{name}/demo_{stage}.py --method {method} --test"
            for stage in ["train", "predict"]
        ]
        execute_commands(stage_cmds)

        check_vis_unpaired_ct_abdomen(name, method)

    @pytest.mark.parametrize(
        "name",
        [
            "classical_ct_headneck_affine",
            "classical_mr_prostate_nonrigid",
        ],
    )
    def test_classical_demo(self, name):
        """Demos that run a registration script instead of train and predict."""
        self._prepare_data(name)

        # execute register
        execute_commands([f"python demos/{name}/demo_register.py --test"])

        check_vis_classical_demo(name)
164