# pylint: disable=line-too-long
"""Launch DeepReg prediction for the ``unpaired_ct_abdomen`` demo.

Usage::

    python demo_predict.py --method {comb,unsup,weakly} [--test | --no-test]

The pretrained checkpoint matching ``--method`` is loaded from
``demos/unpaired_ct_abdomen/dataset/pretrained/<method>/`` and passed to
``deepreg.predict.predict`` in ``"test"`` mode. With ``--test``, the extra
config ``config/test/demo_unpaired_grouped.yaml`` is appended to the config
list. Before running, an equivalent ``deepreg_predict`` command line is
printed for reference.
"""
import argparse
from datetime import datetime

from deepreg.predict import predict

name = "unpaired_ct_abdomen"

# Epoch of the released pretrained checkpoint for each training method; it is
# embedded in the checkpoint file name ``weights-epoch{index}.ckpt``.
ckpt_index_dict = {"comb": 2000, "unsup": 5000, "weakly": 2250}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--method",
    help="Training method, comb or unsup or weakly",
    type=str,
    required=True,
    # Validate via argparse rather than ``assert``: asserts are stripped under
    # ``python -O``, and argparse reports a proper usage error instead.
    choices=list(ckpt_index_dict),
)
parser.add_argument(
    "--test",
    help="Execute the script for test purpose",
    dest="test",
    action="store_true",
)
parser.add_argument(
    "--no-test",
    help="Execute the script for non-test purpose",
    dest="test",
    action="store_false",
)
parser.set_defaults(test=False)
args = parser.parse_args()

method = args.method
ckpt_index = ckpt_index_dict[method]

# Build every path once so the printed hint and the actual call cannot drift.
log_root = f"demos/{name}"
# Timestamp suffix gives each run its own log directory.
log_dir = f"logs_predict/{method}/" + datetime.now().strftime("%Y%m%d-%H%M%S")
ckpt_path = f"{log_root}/dataset/pretrained/{method}/weights-epoch{ckpt_index}.ckpt"
config_path = [f"{log_root}/{name}_{method}.yaml"]
if args.test:
    config_path.append("config/test/demo_unpaired_grouped.yaml")

# NOTE(review): the suggested command uses ``--gpu ''`` while the call below
# uses gpu="0"; confirm which device setting the hint is meant to show.
print(
    "\n\n\n\n\n"
    "=========================================================\n"
    "The prediction can also be launched using the following command.\n"
    "deepreg_predict --gpu '' "
    f"--config_path {config_path[0]} "
    f"--ckpt_path {ckpt_path} "
    f"--log_root {log_root} "
    f"--log_dir logs_predict/{method} "
    "--save_png --mode test\n"
    "=========================================================\n"
    "\n\n\n\n\n"
)

predict(
    gpu="0",
    gpu_allow_growth=True,
    ckpt_path=ckpt_path,
    mode="test",
    batch_size=1,
    log_root=log_root,
    log_dir=log_dir,
    sample_label="all",
    config_path=config_path,
)