Completed
Push — master ( c48f07...4ce1c1 )
by Raphael
01:33
created

TrainLogger.load()   B

Complexity

Conditions 5

Size

Total Lines 12

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 5
c 2
b 0
f 0
dl 0
loc 12
rs 8.5454
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import os
5
import datetime
6
import logging as loggers
7
import deepy
8
# Module-level logger. Note: it is bound to the name `logging`, shadowing the
# stdlib module name (the module itself is imported above as `loggers`).
logging = loggers.getLogger(__name__)
9
10
# Line prefixes used in the saved training log to store the epoch and
# progress counters; TrainLogger.save writes "<prefix> <int>" lines and
# TrainLogger.load parses lines starting with these prefixes back into ints.
PROGRESS_LOG_PREFIX = "progress:"
11
EPOCH_LOG_PREFIX = "epoch:"
12
13
class TrainLogger(object):
    """
    Keep a textual training log alongside a saved model.

    Free-form log lines are collected in ``log_pool``; the latest epoch and
    progress counters are tracked separately. ``save`` writes all of them to
    a ``.log`` file derived from the model path, and ``load`` reads such a
    file back.
    """

    # Start of the header line that ``save`` writes at the top of every log.
    _VERSION_HEADER_PREFIX = "# deepy version:"

    def __init__(self):
        self.log_pool = []
        self._progress = 0
        self._epoch = 0

    def load(self, model_path):
        """
        Restore log lines and counters from the log file of a model.

        Reads the ``.log`` file derived from *model_path* (see ``_log_path``)
        if it exists; does nothing otherwise. Lines starting with
        ``EPOCH_LOG_PREFIX`` / ``PROGRESS_LOG_PREFIX`` set the epoch /
        progress counters (a later line overrides an earlier one). The
        version header written by ``save`` and blank lines are skipped.
        Every other line is stripped and appended to ``log_pool``.

        Raises ValueError if a counter line does not hold an integer.
        """
        log_path = self._log_path(model_path)
        if not os.path.exists(log_path):
            return
        logging.info("Load training log from %s", log_path)
        # ``with`` guarantees the file is closed; iterating the file object
        # works on both Python 2 and 3 (``xreadlines`` is Python 2 only).
        with open(log_path) as log_file:
            for line in log_file:
                if line.startswith(EPOCH_LOG_PREFIX):
                    self._epoch = int(line[len(EPOCH_LOG_PREFIX):].strip())
                elif line.startswith(PROGRESS_LOG_PREFIX):
                    self._progress = int(line[len(PROGRESS_LOG_PREFIX):].strip())
                elif line.startswith(self._VERSION_HEADER_PREFIX):
                    # ``save`` writes a fresh header on every call; keeping the
                    # old one in ``log_pool`` would duplicate it on each
                    # save/load round trip.
                    continue
                else:
                    entry = line.strip()
                    if entry:
                        self.log_pool.append(entry)

    def record(self, line):
        """
        Append *line* to the log, prefixed with the current local time in
        the form ``[YYYY/MM/DD HH:MM:SS] ``.
        """
        time_mark = datetime.datetime.now().strftime("[%Y/%m/%d %H:%M:%S] ")
        self.log_pool.append(time_mark + line)

    def record_progress(self, progress):
        """
        Record current progress in the training.
        """
        self._progress = progress

    def record_epoch(self, epoch):
        """
        Record the current training epoch.
        """
        self._epoch = epoch

    def progress(self):
        """
        Return the recorded or loaded progress (0 if none was set).
        """
        return self._progress

    def epoch(self):
        """
        Return the recorded or loaded epoch (0 if none was set).
        """
        return self._epoch

    def save(self, model_path):
        """
        Write the log to the ``.log`` file derived from *model_path*.

        The file (overwritten if it exists) contains a deepy version header,
        every line of ``log_pool``, then an epoch line and a progress line
        for each counter that is greater than zero.
        """
        log_path = self._log_path(model_path)
        with open(log_path, "w") as outf:
            outf.write("%s %s\n" % (self._VERSION_HEADER_PREFIX, deepy.__version__))
            for line in self.log_pool:
                outf.write(line + "\n")
            if self._epoch > 0:
                outf.write("%s %d\n" % (EPOCH_LOG_PREFIX, self._epoch))
            if self._progress > 0:
                outf.write("%s %d\n" % (PROGRESS_LOG_PREFIX, self._progress))

    def _log_path(self, model_path):
        """
        Return *model_path* with its file extension (if any) replaced by
        ``.log``.
        """
        # splitext only looks at the final path component, so dots inside
        # directory names (e.g. "/runs/v1.2/model") are left untouched;
        # splitting on the last "." of the whole path would cut there.
        return os.path.splitext(model_path)[0] + ".log"
70