Completed
Push — main ( 0c57ec...f6b5bf )
by Yunguan
18s queued 13s
created

unpaired_ct_abdomen.demo_predict   A

Complexity

Total Complexity 0

Size/Duplication

Total Lines 70
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 0
eloc 49
dl 0
loc 70
rs 10
c 0
b 0
f 0
# pylint: disable=line-too-long
"""
Run DeepReg prediction for the unpaired_ct_abdomen demo with pretrained weights.

Command-line options:
    --method {comb,unsup,weakly}  training method whose config file and
                                  pretrained checkpoint are used (required).
    --test / --no-test            when --test is given, additionally pass
                                  ``config/test/demo_unpaired_grouped.yaml``
                                  as a config file (default: --no-test).

Before predicting, the script prints an equivalent ``deepreg_predict``
command line. All paths are relative to the current working directory
(e.g. ``demos/unpaired_ct_abdomen``), so run it from the directory that
contains ``demos/``.
"""
import argparse
from datetime import datetime

from deepreg.predict import predict

name = "unpaired_ct_abdomen"
# Epoch number N of the pretrained checkpoint for each training method; it is
# used to build ".../pretrained/{method}/weights-epoch{N}.ckpt" below.
ckpt_index_dict = {"comb": 2000, "unsup": 5000, "weakly": 2250}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--method",
    help="Training method, comb or unsup or weakly",
    type=str,
    required=True,
    # argparse rejects unknown methods with a usage error. Unlike an
    # ``assert``, this validation is not stripped when Python runs with -O.
    choices=list(ckpt_index_dict),
)
parser.add_argument(
    "--test",
    help="Execute the script for test purpose",
    dest="test",
    action="store_true",
)
parser.add_argument(
    "--no-test",
    help="Execute the script for non-test purpose",
    dest="test",
    action="store_false",
)
parser.set_defaults(test=False)
args = parser.parse_args()
method = args.method

# Safe lookup: argparse has already restricted ``method`` to the dict's keys.
ckpt_index = ckpt_index_dict[method]

# Build each path once so the printed command and the real call use the same values.
log_root = f"demos/{name}"
ckpt_path = f"{log_root}/dataset/pretrained/{method}/weights-epoch{ckpt_index}.ckpt"
base_config_path = f"{log_root}/{name}_{method}.yaml"

# NOTE(review): the printed command is only an approximation of the call below.
# It passes --gpu '' (while predict() gets gpu="0"), its --log_dir has no
# timestamp, and it never includes the --test config. Confirm this is intended.
print(
    "\n\n\n\n\n"
    "=========================================================\n"
    "The prediction can also be launched using the following command.\n"
    "deepreg_predict --gpu '' "
    f"--config_path {base_config_path} "
    f"--ckpt_path {ckpt_path} "
    f"--log_root {log_root} "
    f"--log_dir logs_predict/{method} "
    "--save_png --mode test\n"
    "=========================================================\n"
    "\n\n\n\n\n"
)

# Append a second-resolution timestamp so each run gets its own log_dir name.
log_dir = f"logs_predict/{method}/" + datetime.now().strftime("%Y%m%d-%H%M%S")

# Config files are passed as a list. The test config is appended after the
# method config (ordering preserved from the original script).
config_path = [base_config_path]
if args.test:
    config_path.append("config/test/demo_unpaired_grouped.yaml")

predict(
    gpu="0",
    gpu_allow_growth=True,
    ckpt_path=ckpt_path,
    mode="test",
    batch_size=1,
    log_root=log_root,
    log_dir=log_dir,
    sample_label="all",
    config_path=config_path,
)