test_build_pair_output_path()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 25
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 16
dl 0
loc 25
rs 9.6
c 0
b 0
f 0
cc 1
nop 0
1
# coding=utf-8
2
3
"""
4
Tests for deepreg/predict.py
5
pytest style
6
"""
7
8
import os
9
import shutil
10
11
from deepreg.predict import build_config, build_pair_output_path
12
13
14
def test_build_pair_output_path():
    """
    Test build_pair_output_path for labeled and unlabeled cases.

    With a non-negative label index (last element of ``indices``), the second
    returned path is expected to be a ``label_<index>`` sub-directory of the
    pair directory. With a label index of -1, both returned paths are expected
    to be the pair directory itself. In both cases the returned directories
    must exist on disk after the call.
    """
    import tempfile  # local import keeps the file's top-level imports unchanged

    # Work inside a throw-away directory so nothing is left behind in the
    # working tree, even if an assertion fails part-way through.
    with tempfile.TemporaryDirectory() as save_dir:
        pair_dir = os.path.join(save_dir, "pair_1_2")

        # labeled
        got = build_pair_output_path(indices=[1, 2, 0], save_dir=save_dir)
        expected = (pair_dir, os.path.join(pair_dir, "label_0"))
        assert got == expected
        assert os.path.exists(got[0])
        assert os.path.exists(got[1])
        # remove the pair directory so the unlabeled call below has to
        # recreate it rather than passing on leftovers from this call
        shutil.rmtree(got[0])

        # unlabeled
        got = build_pair_output_path(indices=[1, 2, -1], save_dir=save_dir)
        expected = (pair_dir, pair_dir)
        assert got == expected
        assert os.path.exists(got[0])
        assert os.path.exists(got[1])
39
40
41
def test_build_config():
    """
    Exercise build_config with a checkpoint path ending in ``.ckpt``.

    Checks that the returned configuration is a dict and that the returned
    log directory equals ``os.path.join(log_dir, exp_name)``.
    """
    experiment_name = "test_build_config"
    base_log_dir = "logs"

    # TODO: also cover the case of an empty checkpoint path

    # checkpoint path ends with ".ckpt"
    config, built_log_dir, _ = build_config(
        config_path="config/unpaired_labeled_ddf.yaml",
        log_dir=base_log_dir,
        exp_name=experiment_name,
        ckpt_path="example.ckpt",
    )

    expected_log_dir = os.path.join(base_log_dir, experiment_name)
    assert isinstance(config, dict)
    assert built_log_dir == expected_log_dir
60
61
62
def test_predict_on_dataset():
    """
    Placeholder: predict_on_dataset is exercised by
    test_train/test_train_and_predict, so nothing is checked here.
    """
65