|
1
|
|
|
import shutil |
|
2
|
|
|
|
|
3
|
|
|
import numpy as np |
|
4
|
|
|
import tensorflow as tf |
|
5
|
|
|
|
|
6
|
|
|
from deepreg.callback import build_checkpoint_callback |
|
7
|
|
|
|
|
8
|
|
|
|
|
9
|
|
|
def test_restore_checkpoint_manager_callback():
    """
    Test that CheckpointManagerCallback restores a saved checkpoint.

    Trains a toy model while periodically saving checkpoints, then builds a
    fresh model restored from the final checkpoint, verifies every saved
    variable (excluding checkpoint bookkeeping entries) was restored exactly,
    and finally resumes training from the restored epoch.
    """

    # toy model
    class Net(tf.keras.Model):
        """A simple linear model."""

        def __init__(self):
            super().__init__()
            self.l1 = tf.keras.layers.Dense(5)

        # Override `call`, not `__call__`: Keras' own `Model.__call__`
        # wrapper then handles building/tracking as intended.
        def call(self, x, training=False):
            return self.l1(x)

    # toy dataset: 10 scalar inputs, each mapped to 5 linear targets
    def toy_dataset():
        inputs = tf.range(10.0)[:, None]
        labels = inputs * 5.0 + tf.range(5.0)[None, :]
        return tf.data.Dataset.from_tensor_slices((inputs, labels)).repeat().batch(2)

    # train old_model and save
    # NOTE: device type is case-sensitive and must be "GPU" -- the previous
    # lowercase "gpu" never matched, so MirroredStrategy was unreachable.
    if len(tf.config.list_physical_devices("GPU")) > 1:
        strategy = tf.distribute.MirroredStrategy()
    else:  # use default strategy
        strategy = tf.distribute.get_strategy()

    try:
        with strategy.scope():
            old_model = Net()
            old_optimizer = tf.keras.optimizers.Adam(0.1)
            old_model.compile(optimizer=old_optimizer, loss=tf.keras.losses.MSE)
            old_callback, _ = build_checkpoint_callback(
                model=old_model,
                dataset=toy_dataset(),
                log_dir="./test/unit/old",
                save_period=5,
                ckpt_path="",
            )
            old_model.fit(
                x=toy_dataset(), epochs=10, steps_per_epoch=10, callbacks=[old_callback]
            )

        # create new model and restore old_model checkpoint
        with strategy.scope():
            new_model = Net()
            new_optimizer = tf.keras.optimizers.Adam(0.1)
            new_model.compile(optimizer=new_optimizer, loss=tf.keras.losses.MSE)
            new_callback, initial_epoch = build_checkpoint_callback(
                model=new_model,
                dataset=toy_dataset(),
                log_dir="./test/unit/new",
                save_period=5,
                ckpt_path="./test/unit/old/save/ckpt-10",
            )

        # check equal: every saved variable must round-trip exactly,
        # skipping checkpoint bookkeeping entries
        new_callback._manager.save(0)
        old_reader = tf.train.load_checkpoint("./test/unit/old/save/ckpt-10")
        new_reader = tf.train.load_checkpoint("./test/unit/new/save")
        for k in old_reader.get_variable_to_shape_map():
            if "save_counter" in k or "_CHECKPOINTABLE_OBJECT_GRAPH" in k:
                continue
            equal = np.array(old_reader.get_tensor(k)) == np.array(
                new_reader.get_tensor(k)
            )
            assert np.all(equal), "{} fail to restore !".format(k)

        # resume training from the restored epoch
        new_model.fit(
            x=toy_dataset(),
            initial_epoch=initial_epoch,
            epochs=20,
            steps_per_epoch=10,
            callbacks=[new_callback],
        )
    finally:
        # remove temporary ckpt directories even when an assertion above
        # failed, so a broken run cannot pollute subsequent test runs
        shutil.rmtree("./test/unit/old", ignore_errors=True)
        shutil.rmtree("./test/unit/new", ignore_errors=True)
|
87
|
|
|
|