test.unit.test_callback   A
last analyzed

Complexity

Total Complexity 7

Size/Duplication

Total Lines 87
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 7
eloc 57
dl 0
loc 87
rs 10
c 0
b 0
f 0

1 Function

Rating   Name   Duplication   Size   Complexity  
B test_restore_checkpoint_manager_callback() 0 78 7
1
import shutil
2
3
import numpy as np
4
import tensorflow as tf
5
6
from deepreg.callback import build_checkpoint_callback
7
8
9
def test_restore_checkpoint_manager_callback():
    """
    Test restoring a model through CheckpointManagerCallback.

    A toy model is trained while ``build_checkpoint_callback`` saves
    checkpoints every ``save_period`` epochs. A fresh model is then built with
    ``ckpt_path`` pointing at the last checkpoint of that run. Every saved
    variable (except checkpoint bookkeeping entries) must be identical between
    the original checkpoint and the restored model, and training must be able
    to resume from the returned ``initial_epoch``.

    All checkpoints live in a temporary directory that is removed even when
    an assertion fails, so reruns never see stale checkpoints.
    """
    import os
    import tempfile

    save_period = 5
    old_epochs = 10  # the checkpoint restored below is named after this epoch
    new_epochs = 20
    steps_per_epoch = 10

    # toy model
    class Net(tf.keras.Model):
        """A simple linear model."""

        def __init__(self):
            super().__init__()
            self.l1 = tf.keras.layers.Dense(5)

        def call(self, x, training=False):
            # Implement `call` (not `__call__`) so Keras' own `__call__`
            # wrapper still runs around the forward pass.
            return self.l1(x)

    # toy dataset
    def toy_dataset():
        inputs = tf.range(10.0)[:, None]
        labels = inputs * 5.0 + tf.range(5.0)[None, :]
        return tf.data.Dataset.from_tensor_slices((inputs, labels)).repeat().batch(2)

    # TF names physical device types in upper case ("CPU", "GPU"); a
    # lower-case "gpu" never matches, so multi-GPU runs would silently fall
    # back to the default strategy.
    if len(tf.config.list_physical_devices("GPU")) > 1:
        strategy = tf.distribute.MirroredStrategy()
    else:  # use default strategy
        strategy = tf.distribute.get_strategy()

    tmp_dir = tempfile.mkdtemp()
    try:
        old_log_dir = os.path.join(tmp_dir, "old")
        new_log_dir = os.path.join(tmp_dir, "new")
        old_ckpt_path = os.path.join(
            old_log_dir, "save", "ckpt-{}".format(old_epochs)
        )

        # train old_model and save
        with strategy.scope():
            old_model = Net()
            old_optimizer = tf.keras.optimizers.Adam(0.1)
        old_model.compile(optimizer=old_optimizer, loss=tf.keras.losses.MSE)
        old_callback, _ = build_checkpoint_callback(
            model=old_model,
            dataset=toy_dataset(),
            log_dir=old_log_dir,
            save_period=save_period,
            ckpt_path="",
        )
        old_model.fit(
            x=toy_dataset(),
            epochs=old_epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=[old_callback],
        )

        # create new model and restore old_model checkpoint
        with strategy.scope():
            new_model = Net()
            new_optimizer = tf.keras.optimizers.Adam(0.1)
        new_model.compile(optimizer=new_optimizer, loss=tf.keras.losses.MSE)
        new_callback, initial_epoch = build_checkpoint_callback(
            model=new_model,
            dataset=toy_dataset(),
            log_dir=new_log_dir,
            save_period=save_period,
            ckpt_path=old_ckpt_path,
        )

        # Persist the restored state so it can be read back with a checkpoint
        # reader. `_manager` is a private attribute of the callback; no public
        # accessor is visible from this test.
        new_callback._manager.save(0)
        old_reader = tf.train.load_checkpoint(old_ckpt_path)
        # a directory argument resolves to the latest checkpoint inside it
        new_reader = tf.train.load_checkpoint(os.path.join(new_log_dir, "save"))
        for key in old_reader.get_variable_to_shape_map():
            # bookkeeping entries are not expected to match between runs
            if "save_counter" in key or "_CHECKPOINTABLE_OBJECT_GRAPH" in key:
                continue
            np.testing.assert_array_equal(
                np.asarray(old_reader.get_tensor(key)),
                np.asarray(new_reader.get_tensor(key)),
                err_msg="{} fail to restore !".format(key),
            )

        # training must be able to resume from the restored epoch
        new_model.fit(
            x=toy_dataset(),
            initial_epoch=initial_epoch,
            epochs=new_epochs,
            steps_per_epoch=steps_per_epoch,
            callbacks=[new_callback],
        )
    finally:
        # remove temporary ckpt directories even if an assertion failed
        shutil.rmtree(tmp_dir, ignore_errors=True)
87