Completed
Push — master ( 394368...090fba )
by Raphael
01:33
created

TrainingValidator.compare()   B

Complexity

Conditions 6

Size

Total Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
c 0
b 0
f 0
dl 0
loc 18
rs 8
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from collections import OrderedDict, defaultdict
5
6
class TrainingController(object):
    """
    Base class for training controllers.

    A controller is attached to a trainer through ``bind`` and exposes an
    ``invoke`` hook; subclasses override ``invoke`` with their own logic.
    """

    def bind(self, trainer):
        """
        Attach the trainer this controller works with.

        The trainer is stored as ``self._trainer`` for use by subclasses.

        :type trainer: deepy.trainers.base.NeuralTrainer
        """
        self._trainer = trainer

    def invoke(self):
        """
        Controller hook; a True result asks for training to exit.

        This default implementation never requests an exit.

        :return: always False.
        """
        should_exit = False
        return should_exit
22
23
class TrainingValidator(TrainingController):
    """
    A validator that allows validating the model with another graph.

    Every ``freq`` calls to ``invoke`` it runs the validation model over all
    batches of ``data_split``, averages the scalar outputs, reports them to the
    trainer and saves a checkpoint whenever the ``criteria`` value improves.
    """

    def __init__(self, valid_model=None, data_split='valid', freq=1500, save_path=None, criteria='cost',
                 smaller_is_better=True):
        """
        Initialize the training validator.

        :param valid_model: model used for validation; its ``compute(*x)`` must
            return a mapping of output names to values.
        :param data_split: split name passed to ``trainer.get_data``.
        :param freq: validate once every ``freq`` calls to ``invoke``.
        :param save_path: path passed to ``trainer.save_checkpoint`` on a new best.
        :param criteria: output name used to decide whether a result is a new best.
        :param smaller_is_better: True if lower ``criteria`` values are better.
        """
        self._model = valid_model
        self._data_split = data_split
        self._freq = freq
        self._save_path = save_path
        self._criteria = criteria
        self._smaller_is_better = smaller_is_better
        self._best_criteria = None
        self._counter = 0

    def compare(self, cost_map):
        """
        Compare to previous records and return whether the given cost is a new best.

        The first call always counts as a new best. Afterwards the value must be
        strictly better than the best seen so far (ties are not improvements).
        The stored best is updated whenever True is returned.

        :param cost_map: mapping containing the ``criteria`` key.
        :return: True if the given cost is a new best
        :raises KeyError: if ``cost_map`` is a plain mapping lacking ``criteria``.
        """
        cri_val = cost_map[self._criteria]
        if self._best_criteria is None:
            improved = True
        elif self._smaller_is_better:
            improved = cri_val < self._best_criteria
        else:
            improved = cri_val > self._best_criteria
        # Comparisons on numpy scalars yield numpy bools; hand callers a real bool.
        improved = bool(improved)
        if improved:
            self._best_criteria = cri_val
        return improved

    def compute(self, *x):
        """
        Compute with the validation model given data x.
        """
        return self._model.compute(*x)

    def _extract_costs(self, vars):
        """
        Keep only the scalar (``ndim == 0``) outputs, with the criteria first.

        :param vars: mapping of output names to values exposing ``ndim``
            (presumably numpy arrays/scalars -- TODO confirm).
        :return: OrderedDict of scalar outputs; the ``criteria`` entry (when
            scalar) comes first, the others follow in ``vars`` iteration order.
        """
        ret_map = OrderedDict()
        sub_costs = OrderedDict()
        for k, val in vars.items():
            if val.ndim == 0:
                if k == self._criteria:
                    ret_map[k] = val
                else:
                    sub_costs[k] = val
        ret_map.update(sub_costs)
        return ret_map

    def run(self, data_x):
        """
        Run the model with validation data and return costs.
        """
        output_vars = self.compute(*data_x)
        return self._extract_costs(output_vars)

    def _average_costs(self):
        """
        Run ``run`` on every batch of the data split and average the costs.

        Each key's sum is divided by the total number of batches.

        :return: OrderedDict of cost name -> average, in first-seen key order,
            or None if the split produced no batches.
        :raises Exception: if ``run`` returns something that is not a dict.
        """
        cnt = 0
        # A plain OrderedDict (not defaultdict) so later lookups of a missing
        # key fail loudly instead of silently producing 0.0.
        sum_map = OrderedDict()
        for x in self._trainer.get_data(self._data_split):
            val_map = self.run(x)
            if not isinstance(val_map, dict):
                raise Exception("Monitor.run must return a dict.")
            for k, val in val_map.items():
                sum_map[k] = sum_map.get(k, 0.) + val
            cnt += 1
        if cnt == 0:
            return None
        return OrderedDict((k, total / float(cnt)) for k, total in sum_map.items())

    def invoke(self):
        """
        This function will be called after each iteration.

        On every ``freq``-th call the model is validated, the averaged costs
        are reported to the trainer, and a checkpoint is saved if the
        ``criteria`` value is a new best.

        :return: False; validation never asks training to exit.
        :raises KeyError: if the averaged costs do not contain ``criteria``.
        """
        self._counter += 1
        if self._counter % self._freq != 0:
            return False
        avg_map = self._average_costs()
        if avg_map is None:
            # Empty split: nothing to report, and comparing an empty map must
            # not record a fabricated best value.
            return False
        new_best = self.compare(avg_map)
        self._trainer.report(avg_map, self._data_split, new_best=new_best)
        if new_best:
            self._trainer.save_checkpoint(self._save_path)
        return False