| Metric | Value |
| --- | --- |
| Conditions | 7 |
| Total Lines | 78 |
| Code Lines | 52 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 |
Small methods make your code easier to understand, especially when combined with a good name. In turn, if a method is small, finding a good name for it is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for naming it, as sketched below.
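A quick sketch of that comment-driven extraction (the names here are made up for illustration, not taken from the code under review):

```python
# Before: a comment labels a distinct step inside a longer method.
def total_revenue(orders):
    # keep only orders that were not cancelled
    valid = [o for o in orders if o["status"] != "cancelled"]
    return sum(o["amount"] for o in valid)


# After: the commented step becomes its own method, and the comment
# is the starting point for its name.
def keep_uncancelled(orders):
    return [o for o in orders if o["status"] != "cancelled"]


def total_revenue(orders):
    return sum(o["amount"] for o in keep_uncancelled(orders))
```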
Commonly applied refactorings include:

- Extract Method

If many parameters/temporary variables are present:

- Replace Method with Method Object (see the sketch after this list)
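Replace Method with Method Object turns the parameters and temporaries of an overgrown method into fields of a small class, after which the body can be split into well-named steps. A minimal sketch with hypothetical names:

```python
class PriceCalculator:
    """Method object: former parameters and temporaries become fields."""

    def __init__(self, base_price, quantity, discount_rate):
        self.base_price = base_price
        self.quantity = quantity
        self.discount_rate = discount_rate

    def compute(self):
        return self._subtotal() - self._discount()

    def _subtotal(self):
        return self.base_price * self.quantity

    def _discount(self):
        return self._subtotal() * self.discount_rate


# The original many-parameter method shrinks to a one-line delegation.
def price(base_price, quantity, discount_rate):
    return PriceCalculator(base_price, quantity, discount_rate).compute()


assert price(10.0, 3, 0.1) == 27.0
```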
```python
import shutil

import numpy as np
import tensorflow as tf

# build_checkpoint_callback comes from the project's callback module;
# its import sits in the elided lines (2-8) of the original file.


def test_restore_checkpoint_manager_callback():
    """
    Test restoring a model via CheckpointManagerCallback.
    """

    # toy model
    class Net(tf.keras.Model):
        """A simple linear model."""

        def __init__(self):
            super().__init__()
            self.l1 = tf.keras.layers.Dense(5)

        def call(self, x, training=False):
            return self.l1(x)

    # toy dataset
    def toy_dataset():
        inputs = tf.range(10.0)[:, None]
        labels = inputs * 5.0 + tf.range(5.0)[None, :]
        return tf.data.Dataset.from_tensor_slices((inputs, labels)).repeat().batch(2)

    # train old_model and save checkpoints
    if len(tf.config.list_physical_devices("GPU")) > 1:
        strategy = tf.distribute.MirroredStrategy()
    else:  # use default strategy
        strategy = tf.distribute.get_strategy()

    with strategy.scope():
        old_model = Net()
        old_optimizer = tf.keras.optimizers.Adam(0.1)
        old_model.compile(optimizer=old_optimizer, loss=tf.keras.losses.MSE)
    old_callback, _ = build_checkpoint_callback(
        model=old_model,
        dataset=toy_dataset(),
        log_dir="./test/unit/old",
        save_period=5,
        ckpt_path="",
    )
    old_model.fit(
        x=toy_dataset(), epochs=10, steps_per_epoch=10, callbacks=[old_callback]
    )

    # create new model and restore old_model checkpoint
    with strategy.scope():
        new_model = Net()
        new_optimizer = tf.keras.optimizers.Adam(0.1)
        new_model.compile(optimizer=new_optimizer, loss=tf.keras.losses.MSE)
    new_callback, initial_epoch = build_checkpoint_callback(
        model=new_model,
        dataset=toy_dataset(),
        log_dir="./test/unit/new",
        save_period=5,
        ckpt_path="./test/unit/old/save/ckpt-10",
    )

    # check that the restored variables equal the saved ones
    new_callback._manager.save(0)
    old_reader = tf.train.load_checkpoint("./test/unit/old/save/ckpt-10")
    new_reader = tf.train.load_checkpoint("./test/unit/new/save")
    for k in old_reader.get_variable_to_shape_map().keys():
        if "save_counter" not in k and "_CHECKPOINTABLE_OBJECT_GRAPH" not in k:
            equal = np.array(old_reader.get_tensor(k)) == np.array(
                new_reader.get_tensor(k)
            )
            assert np.all(equal), "{} failed to restore!".format(k)

    new_model.fit(
        x=toy_dataset(),
        initial_epoch=initial_epoch,
        epochs=20,
        steps_per_epoch=10,
        callbacks=[new_callback],
    )

    # remove temporary ckpt directories
    shutil.rmtree("./test/unit/old")
    shutil.rmtree("./test/unit/new")
```