# pylint: disable=line-too-long
"""Run prediction for the ``unpaired_ct_abdomen`` demo from a pre-trained checkpoint.

The script selects the config file and checkpoint that belong to the chosen
training method (``comb``, ``unsup`` or ``weakly``), prints an equivalent
``deepreg_predict`` command line for reference, and then calls
``deepreg.predict.predict`` on the ``test`` split.

All paths are relative (``demos/...``, ``config/...``), so the script is
presumably meant to be launched from the repository root -- TODO confirm.
"""
import argparse
from datetime import datetime

from deepreg.predict import predict

name = "unpaired_ct_abdomen"
# Index of the pre-trained checkpoint to load for each training method.
ckpt_index_dict = {"comb": 5000, "unsup": 5000, "weakly": 5000}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--method",
    help="Training method, comb or unsup or weakly",
    type=str,
    # Let argparse reject unknown methods with a usage error; this replaces an
    # ``assert`` that would be stripped when Python runs with ``-O``.
    choices=list(ckpt_index_dict),
    required=True,
)
# ``--test`` and ``--full`` share one destination: the last flag given wins,
# and the full configuration is used when neither is passed.
parser.add_argument(
    "--test",
    help="Execute the script with reduced image size for test purpose.",
    dest="test",
    action="store_true",
)
parser.add_argument(
    "--full",
    help="Execute the script with full configuration.",
    dest="test",
    action="store_false",
)
parser.set_defaults(test=False)
args = parser.parse_args()
method = args.method

ckpt_index = ckpt_index_dict[method]
# NOTE(review): the hint below passes ``--gpu ''`` while the predict() call
# further down uses gpu="0" -- confirm which one is intended.
print(
    "\n\n\n\n\n"
    "=========================================================\n"
    "The prediction can also be launched using the following command.\n"
    "deepreg_predict --gpu '' "
    f"--config_path demos/{name}/{name}_{method}.yaml "
    f"--ckpt_path demos/{name}/dataset/pretrained/{method}/ckpt-{ckpt_index} "
    f"--log_dir demos/{name} "
    # Was a second ``--log_dir``; the matching predict() argument is exp_name.
    f"--exp_name logs_predict/{method} "
    "--save_png --split test\n"
    "=========================================================\n"
    "\n\n\n\n\n"
)

log_dir = f"demos/{name}"
# A timestamp suffix gives every run its own experiment name under log_dir.
exp_name = f"logs_predict/{method}/" + datetime.now().strftime("%Y%m%d-%H%M%S")
ckpt_path = f"{log_dir}/dataset/pretrained/{method}/ckpt-{ckpt_index}"
config_path = [f"{log_dir}/{name}_{method}.yaml"]
if args.test:
    # In test mode an additional config file is appended after the demo config.
    config_path.append("config/test/demo_unpaired_grouped.yaml")

predict(
    gpu="0",
    gpu_allow_growth=True,
    ckpt_path=ckpt_path,
    split="test",
    batch_size=1,
    log_dir=log_dir,
    exp_name=exp_name,
    config_path=config_path,
)