deepreg.callback   A
last analyzed

Complexity

Total Complexity 14

Size/Duplication

Total Lines 101
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 14
eloc 60
dl 0
loc 101
rs 10
c 0
b 0
f 0

7 Methods

Rating   Name   Duplication   Size   Complexity  
A CheckpointManagerCallback.__init__() 0 23 1
A CheckpointManagerCallback._on_begin() 0 3 2
A CheckpointManagerCallback.on_train_begin() 0 2 1
A CheckpointManagerCallback.restore() 0 5 2
A CheckpointManagerCallback.on_epoch_end() 0 5 2
A CheckpointManagerCallback._save() 0 7 2
A CheckpointManagerCallback.on_train_end() 0 3 2

1 Function

Rating   Name   Duplication   Size   Complexity  
A build_checkpoint_callback() 0 38 2
1
from typing import Tuple
2
3
import tensorflow as tf
4
5
6
class CheckpointManagerCallback(tf.keras.callbacks.Callback):
    """Keras callback saving and restoring through `tf.train.CheckpointManager`."""

    def __init__(
        self, model, directory, period: int = 1, save_on_train_end: bool = True
    ):
        """
        Create the checkpoint object and the manager that writes it.

        :param model: model to checkpoint, its `optimizer` attribute is
            tracked alongside the model itself
        :param directory: directory to store the checkpoints
        :param period: save the checkpoint every X epochs
        :param save_on_train_end: save the checkpoint as the training ends
        """
        super().__init__()

        # user-facing configuration
        self._directory = directory
        self._period = period
        self._save_on_train_end = save_on_train_end

        # bookkeeping, updated while training runs
        self._restored = False  # becomes True after the first restore() call
        self._epoch_count = None  # number of epochs finished so far
        self._last_save = None  # epoch count at the most recent save

        self._checkpoint = tf.train.Checkpoint(model=model, optimizer=model.optimizer)
        # max_to_keep=None is passed through to the manager unchanged
        self._manager = tf.train.CheckpointManager(
            checkpoint=self._checkpoint, directory=self._directory, max_to_keep=None
        )

    def _on_begin(self):
        """Restore the latest checkpoint unless a restore was already done."""
        if self._restored:
            return
        self.restore()

    def restore(self, save_path=None):
        """
        Restore the tracked objects from a checkpoint.

        :param save_path: checkpoint path to restore from, when None the
            manager's latest checkpoint is used instead
        """
        target = self._manager.latest_checkpoint if save_path is None else save_path
        self._checkpoint.restore(target)
        self._restored = True

    def on_train_begin(self, logs=None):
        """
        Trigger the restore-if-needed step when training starts.

        :param logs: unused
        """
        self._on_begin()

    def on_epoch_end(self, epoch, logs=None):
        """
        Record the finished epoch and save when the period is reached.

        :param epoch: zero-based index of the epoch that just finished
        :param logs: unused
        """
        completed = epoch + 1
        self._epoch_count = completed
        if completed % self._period != 0:
            return
        self._save()

    def on_train_end(self, logs=None):
        """
        Save once more when training ends, if configured to do so.

        :param logs: unused
        """
        if not self._save_on_train_end:
            return
        self._save()

    def _save(self):
        """
        Save a checkpoint numbered by the current epoch count.

        Checkpoint saved as f"{self._directory}/ckpt-{self._epoch_count}".
        Nothing is written if this epoch count was already saved.
        """
        current = self._epoch_count
        if self._last_save == current:
            return
        self._manager.save(checkpoint_number=current)
        self._last_save = current
61
62
63
def build_checkpoint_callback(
    model: tf.keras.Model,
    dataset: tf.data.Dataset,
    log_dir: str,
    save_period: int,
    ckpt_path: str,
) -> Tuple[CheckpointManagerCallback, int]:
    """
    Build the checkpoint callback and determine the epoch to start from.

    :param model: model to train
    :param dataset: dataset for training
    :param log_dir: directory of logs, checkpoints are stored in
        log_dir + "/save"
    :param save_period: save the checkpoint every X epochs
    :param ckpt_path: path to restore ckpt, must end with "-<epoch count>";
        if empty, nothing is restored here and the initial epoch is 0
    :return: a tuple (checkpoint manager callback, initial epoch)
    :raises AssertionError: if ckpt_path is non-empty and its part after the
        last "-" is not a decimal number
    """
    # fit the model for 1 step to initialise optimiser arguments as trackable Variables
    model.fit(
        x=dataset,
        steps_per_epoch=1,
        epochs=1,
        verbose=0,
    )
    checkpoint_manager_callback = CheckpointManagerCallback(
        model, log_dir + "/save", period=save_period
    )
    if not ckpt_path:
        return checkpoint_manager_callback, 0

    initial_epoch_str = ckpt_path.split("-")[-1]
    # isdecimal() matches what int() can parse; isdigit() also accepts
    # characters such as superscripts that int() rejects.
    if not initial_epoch_str.isdecimal():
        # Raised explicitly instead of using `assert` so the validation is
        # not stripped when Python runs with -O.
        raise AssertionError(
            "Checkpoint path for checkpoint manager "
            f"must be of form ckpt-epoch_count, got {ckpt_path}"
        )
    initial_epoch = int(initial_epoch_str)
    checkpoint_manager_callback.restore(ckpt_path)
    return checkpoint_manager_callback, initial_epoch
101