Completed: Push to master (91b7c0...d52c79), created by Raphael, 01:31

ScheduledTrainingServer (rating: A)

Complexity

Total Complexity: 25

Size/Duplication

Total Lines: 150
Duplicated Lines: 0%

Importance

Changes: 1
Bugs: 0    Features: 0
Metric   Value
dl       0      (duplicated lines)
loc      150    (lines of code)
rs       10
c        1      (changes)
b        0      (bugs)
f        0      (features)
wmc      25     (total complexity)

5 Methods

Rating   Name                 Duplication   Size   Complexity
B        __init__()           0             32     1
A        feed_hyperparams()   0             7      1
A        feed_batches()       0             9      2
F        handle_control()     0             70     18
A        prepare_epoch()      0             21     3
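The total complexity (wmc) of 25 is simply the sum of the per-method complexities above: 1 + 1 + 2 + 18 + 3 = 25. handle_control() alone accounts for 18 of those points, consistent with it being the only method rated F.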
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function
import time
import sys
import numpy as np
import logging as loggers
logging = loggers.getLogger("ScheduledTrainingServer")
loggers.basicConfig(level=loggers.INFO)

from platoon.channel import Controller
from argparse import ArgumentParser


CONTROLLER_PORT = 5567


class ScheduledTrainingServer(Controller):
    """
    This multi-process controller implements schedule-based SGD: the learning
    rate is halved every epoch once start_halving_at is reached.
    """

    def __init__(self, port=CONTROLLER_PORT, start_halving_at=5, end_at=10,
                 step_len=10,
                 valid_freq=1000,
                 learning_rate=0.1,
                 easgd_alpha=0.5):
        """
        Initialize the controller.

        Args:
            port (int): port for the controller to listen on
            start_halving_at (int): epoch from which the learning rate is halved
            end_at (int): last training epoch
            step_len (int): batches in one training step
            valid_freq (int): number of batches between validations
            learning_rate (float): initial learning rate
            easgd_alpha (float): moving rate of the EASGD update
        """

        Controller.__init__(self, port)
        self.epoch_start_halving = start_halving_at
        self.end_at = end_at
        self.step_len = step_len
        self.start_time = None
        self.rand = np.random.RandomState(3)
        self.epoch = 0
        self._current_iter = 0
        self._iters_from_last_valid = 0
        self._evaluating = False
        self._valid_freq = valid_freq
        self._done = False
        self._lr = learning_rate
        self._easgd_alpha = easgd_alpha

        self.num_train_batches = 0
        self.batch_pool = []
        self._train_costs = []
        self.prepared_worker_pool = set()
        logging.info("multi-gpu server is listening on port {}".format(port))

    def prepare_epoch(self):
        """
        Prepare for one epoch.
        Returns:
            bool: False if the training should stop.
        """
        self.epoch += 1
        # With the defaults (start_halving_at=5, learning_rate=0.1), epochs 1-4
        # run at lr=0.1, then the rate is halved every epoch: 0.05, 0.025, ...
        if self.epoch >= self.epoch_start_halving:
            self._lr *= 0.5
        self._current_iter = 0
        self._iters_from_last_valid = 0
        self._train_costs = []
        self.prepared_worker_pool.clear()
        # list() keeps the pool shuffleable and sliceable on Python 3, where
        # range() returns a lazy sequence instead of a list.
        self.batch_pool = list(range(self.num_train_batches))
        self.rand.shuffle(self.batch_pool)
        if self.epoch > self.end_at:
            logging.info("Training is done, waiting for all workers to stop")
            return False
        else:
            logging.info("start epoch {} with lr={}".format(self.epoch, self._lr))
            return True

    def feed_batches(self):
        if not self.batch_pool:
            return None
        else:
            batches = self.batch_pool[:self.step_len]
            self.batch_pool = self.batch_pool[self.step_len:]
            self._current_iter += len(batches)
            self._iters_from_last_valid += len(batches)
            return batches

    def feed_hyperparams(self):
        retval = {
            "epoch": self.epoch,
            "learning_rate": self._lr,
            "easgd_alpha": self._easgd_alpha
        }
        return retval

    def handle_control(self, req, worker_id):
        """
        Handle a control request received from a worker.
        Returns:
            str or dict: response

            'stop' - the worker should quit
            'wait' - wait for 1 second
            'eval' - evaluate on the valid and test sets, then start a new epoch
            'sync_hyperparams' - set the learning rate
            'valid' - evaluate on the valid and test sets, then save the params
            'train' - train on the next batches
        """
        if self.start_time is None:
            self.start_time = time.time()
        response = ""

        if req == 'next':
            if self.num_train_batches == 0:
                # The batch count is unknown; ask the worker to report it.
                response = "get_num_batches"
            elif self._done:
                response = "stop"
                self.worker_is_done(worker_id)
            elif self._evaluating:
                response = 'wait'
            elif not self.batch_pool:
                # End of one epoch
                response = 'eval'
                self._evaluating = True
            else:
                # Continue training
                if worker_id not in self.prepared_worker_pool:
                    response = {"sync_hyperparams": self.feed_hyperparams()}
                    self.prepared_worker_pool.add(worker_id)
                elif self._iters_from_last_valid >= self._valid_freq:
                    response = 'valid'
                    self._iters_from_last_valid = 0
                else:
                    response = {"train": self.feed_batches()}
        elif 'eval_done' in req:
            messages = req['eval_done']
            self._evaluating = False
            sys.stdout.write("\r")
            sys.stdout.flush()
            for msg in messages:
                logging.info(msg)
            continue_training = self.prepare_epoch()
            if not continue_training:
                self._done = True
                logging.info("training time {:.4f}s".format(time.time() - self.start_time))
                response = "stop"
        elif 'valid_done' in req:
            messages = req['valid_done']
            sys.stdout.write("\r")
            sys.stdout.flush()
            for msg in messages:
                logging.info(msg)
        elif 'train_done' in req:
            costs = req['costs']
            self._train_costs.append(costs)
            # Integer division keeps the progress percentage an int on Python 3.
            sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f" % (self._current_iter * 100 // self.num_train_batches,
                                                           costs[0]))
            sys.stdout.flush()
        elif 'get_num_batches_done' in req:
            self.num_train_batches = req['get_num_batches_done']
        elif 'get_easgd_alpha' in req:
            response = self._easgd_alpha
        elif 'sync_hyperparams' in req:
            response = {"sync_hyperparams": self.feed_hyperparams()}

        return response


if __name__ == '__main__':
    ap = ArgumentParser()
    ap.add_argument("--port", type=int, default=5567)
    ap.add_argument("--learning_rate", type=float, default=0.01)
    ap.add_argument("--start_halving_at", type=int, default=5)
    ap.add_argument("--end_at", type=int, default=10)
    ap.add_argument("--step_len", type=int, default=10)
    ap.add_argument("--valid_freq", type=int, default=1500)
    ap.add_argument("--easgd_alpha", type=float, default=0.5)
    args = ap.parse_args()

    server = ScheduledTrainingServer(
        port=args.port, learning_rate=args.learning_rate,
        start_halving_at=args.start_halving_at,
        end_at=args.end_at,
        step_len=args.step_len,
        valid_freq=args.valid_freq,
        easgd_alpha=args.easgd_alpha)
    server.serve()
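
For reference, the sketch below shows the worker-side counterpart of the control protocol that handle_control() implements. It is a minimal sketch, not part of this push: the request and response keys ('next', 'get_num_batches', 'get_num_batches_done', 'train_done' with 'costs', 'eval_done', 'valid_done', 'sync_hyperparams') are taken from handle_control() above, while the Worker constructor arguments and the num_train_batches, train_on_batches(), run_evaluation() and apply_hyperparams() stubs are assumptions made for illustration.

import time

from platoon.channel import Worker

# Hypothetical stand-ins for the real model code.
num_train_batches = 1000

def train_on_batches(batch_indices):
    # Would run SGD on the given batch indices; returns costs, J first.
    return [0.0]

def run_evaluation():
    # Would evaluate on the valid/test sets; returns log lines for the server.
    return ["valid J=0.00"]

def apply_hyperparams(params):
    # Would set the local epoch, learning_rate and easgd_alpha.
    pass

worker = Worker(control_port=5567)  # constructor arguments are an assumption

while True:
    resp = worker.send_req('next')
    if resp == 'stop':
        break
    elif resp == 'wait':
        time.sleep(1)
    elif resp == 'get_num_batches':
        # The server asks for the batch count before the first epoch.
        worker.send_req({'get_num_batches_done': num_train_batches})
    elif resp in ('eval', 'valid'):
        key = 'eval_done' if resp == 'eval' else 'valid_done'
        worker.send_req({key: run_evaluation()})
    elif isinstance(resp, dict) and 'sync_hyperparams' in resp:
        # First request of each epoch: adopt the server's epoch, lr and alpha.
        apply_hyperparams(resp['sync_hyperparams'])
    elif isinstance(resp, dict) and 'train' in resp:
        costs = train_on_batches(resp['train'])
        worker.send_req({'train_done': None, 'costs': costs})

Each worker simply polls 'next' in a loop; the server serializes epoch boundaries through its _evaluating flag, so the first worker to exhaust the batch pool runs the end-of-epoch evaluation while the others receive 'wait'.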