Inspection completed: push to master (d52c79...139da6) by Raphael, created at 01:29

ScheduledTrainingServer (class rating: B)

Complexity        Total Complexity 37
Size/Duplication  Total Lines 194, Duplicated Lines 0%
Importance        Changes 1, Bugs 0, Features 0
Metric                              Value
c    (changes)                      1
b    (bugs)                         0
f    (features)                     0
dl   (duplicated lines)             0
loc  (lines of code)                194
rs   (rating score)                 8.6
wmc  (weighted method complexity)   37

6 Methods

Rating  Name                  Duplication  Size  Complexity
B       __init__()            0            36    1
A       get_monitor_string()  0            2     2
A       feed_hyperparams()    0            7     1
A       feed_batches()        0            9     2
F       handle_control()      0            107   28
A       prepare_epoch()       0            21    3
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function
import time
import sys, os
import numpy as np
import logging as loggers
from threading import Lock
logging = loggers.getLogger("ScheduledTrainingServer")
loggers.basicConfig(level=loggers.INFO)

from platoon.channel import Controller
from argparse import ArgumentParser


CONTROLLER_PORT = 5567

class ScheduledTrainingServer(Controller):
    """
    This multi-process controller implements scheduled SGD training:
    the learning rate is halved on every epoch once `start_halving_at`
    is reached, and training stops after `end_at` epochs.
    """

    def __init__(self, port=CONTROLLER_PORT, start_halving_at=5, end_at=10, step_len=10,
                 valid_freq=1000,
                 learning_rate=0.1,
                 easgd_alpha=0.5):
        """
        Initialize the controller.

        Args:
            port (int): port the controller listens on
            start_halving_at (int): epoch from which the learning rate is halved each epoch
            end_at (int): last epoch of training
            step_len (int): number of batches in one training step
            valid_freq (int): run validation after this many batches
            learning_rate (float): initial learning rate
            easgd_alpha (float): elastic averaging coefficient served to the workers
        """

        Controller.__init__(self, port)
        self.epoch_start_halving = start_halving_at
        self.end_at = end_at
        self.step_len = step_len
        self.start_time = None
        self.rand = np.random.RandomState(3)
        self.epoch = 0
        self._current_iter = 0
        self._iters_from_last_valid = 0
        self._evaluating = False
        self._valid_freq = valid_freq
        self._done = False
        self._lr = learning_rate
        self._easgd_alpha = easgd_alpha
        self._training_names = []
        self._evaluation_names = []
        self._best_valid_cost = sys.float_info.max
        self._lock = Lock()

        self.num_train_batches = 0
        self.batch_pool = []
        self._train_costs = []
        self.prepared_worker_pool = set()
        logging.info("multi-gpu server is listening on port {}".format(port))

    def prepare_epoch(self):
        """
        Prepare for one epoch.

        Returns:
            bool: False if the training should stop.
        """
        self.epoch += 1
        # Halve the learning rate on every epoch once the halving threshold
        # is reached, so it decays geometrically from that point on.
        if self.epoch >= self.epoch_start_halving:
            self._lr *= 0.5
        self._current_iter = 0
        self._iters_from_last_valid = 0
        self._train_costs = []
        self.prepared_worker_pool.clear()
        # list() is needed so the pool can be shuffled and sliced
        # (range objects are immutable under Python 3).
        self.batch_pool = list(range(self.num_train_batches))
        self.rand.shuffle(self.batch_pool)
        if self.epoch > self.end_at:
            logging.info("Training is done, waiting for all workers to stop")
            return False
        else:
            logging.info("start epoch {} with lr={}".format(self.epoch, self._lr))
            return True

    def feed_batches(self):
        if not self.batch_pool:
            return None
        else:
            # Hand out the next `step_len` batch indices to the worker.
            batches = self.batch_pool[:self.step_len]
            self.batch_pool = self.batch_pool[self.step_len:]
            self._current_iter += len(batches)
            self._iters_from_last_valid += len(batches)
            return batches

    def feed_hyperparams(self):
        # Snapshot of the hyperparameters a worker has to synchronize.
        retval = {
            "epoch": self.epoch,
            "learning_rate": self._lr,
            "easgd_alpha": self._easgd_alpha
        }
        return retval

    def get_monitor_string(self, costs):
        # Format (name, cost) pairs as "name1=0.12 name2=0.34".
        return " ".join(["{}={:.2f}".format(n, c) for (n, c) in costs])

    def handle_control(self, req, worker_id):
        """
        Handle a control request received from a worker.

        Returns:
            string or dict: response

            'stop' - the worker should quit
            'wait' - wait for 1 second
            'eval' - evaluate on the valid and test sets, then start a new epoch
            'sync_hyperparams' - set learning rate
            'valid' - evaluate on the valid and test sets, then save the params
            'train' - train the next batches
        """
        if self.start_time is None:
            self.start_time = time.time()
        response = ""

        if req == 'next':
            if self.num_train_batches == 0:
                # The controller does not know the dataset size yet.
                response = "get_num_batches"
            elif self._done:
                response = "stop"
                self.worker_is_done(worker_id)
            elif self._evaluating:
                response = 'wait'
            elif not self.batch_pool:
                # End of one epoch: report the mean training costs and
                # switch to evaluation.
                if self._train_costs:
                    with self._lock:
                        sys.stdout.write("\r")
                        sys.stdout.flush()
                        mean_costs = []
                        for i in range(len(self._training_names)):
                            mean_costs.append(np.mean([c[i] for c in self._train_costs]))
                        logging.info("train   (epoch={:2d}) {}".format(
                            self.epoch,
                            self.get_monitor_string(zip(self._training_names, mean_costs)))
                        )
                response = {'eval': None, 'best_valid_cost': self._best_valid_cost}
                self._evaluating = True
            else:
                # Continue training
                if worker_id not in self.prepared_worker_pool:
                    # First request from this worker in the current epoch:
                    # synchronize hyperparameters before feeding batches.
                    response = {"sync_hyperparams": self.feed_hyperparams()}
                    self.prepared_worker_pool.add(worker_id)
                elif self._iters_from_last_valid >= self._valid_freq:
                    response = {'valid': None, 'best_valid_cost': self._best_valid_cost}
                    self._iters_from_last_valid = 0
                else:
                    response = {"train": self.feed_batches()}
        elif 'eval_done' in req:
            with self._lock:
                self._evaluating = False
                sys.stdout.write("\r")
                sys.stdout.flush()
                if 'test_costs' in req:
                    logging.info("test    (epoch={:2d}) {}".format(
                        self.epoch,
                        self.get_monitor_string(req['test_costs']))
                    )
                if 'valid_costs' in req:
                    valid_J = req['valid_costs'][0][1]
                    if valid_J < self._best_valid_cost:
                        self._best_valid_cost = valid_J
                        star_str = "*"
                    else:
                        star_str = ""
                    logging.info("valid   (epoch={:2d}) {} {}".format(
                        self.epoch,
                        self.get_monitor_string(req['valid_costs']),
                        star_str))
                continue_training = self.prepare_epoch()
                if not continue_training:
                    self._done = True
                    logging.info("training time {:.4f}s".format(time.time() - self.start_time))
                    response = "stop"
        elif 'valid_done' in req:
            # Mid-epoch validation: track the best cost without starting
            # a new epoch.
            with self._lock:
                sys.stdout.write("\r")
                sys.stdout.flush()
                if 'valid_costs' in req:
                    valid_J = req['valid_costs'][0][1]
                    if valid_J < self._best_valid_cost:
                        self._best_valid_cost = valid_J
                        star_str = "*"
                    else:
                        star_str = ""
                    logging.info("valid   ( dryrun ) {} {}".format(
                        self.get_monitor_string(req['valid_costs']),
                        star_str
                    ))
        elif 'train_done' in req:
            costs = req['costs']
            self._train_costs.append(costs)
            # Progress line: percentage of the epoch done and current cost.
            sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f" % (self._current_iter * 100 / self.num_train_batches,
                                                           costs[0]))
            sys.stdout.flush()
        elif 'get_num_batches_done' in req:
            self.num_train_batches = req['get_num_batches_done']
        elif 'get_easgd_alpha' in req:
            response = self._easgd_alpha
        elif 'sync_hyperparams' in req:
            response = {"sync_hyperparams": self.feed_hyperparams()}
        elif 'set_names' in req:
            self._training_names = req['training_names']
            self._evaluation_names = req['evaluation_names']

        return response

if __name__ == '__main__':
    ap = ArgumentParser()
    ap.add_argument("--port", type=int, default=5567)
    ap.add_argument("--learning_rate", type=float, default=0.01)
    ap.add_argument("--start_halving_at", type=int, default=5)
    ap.add_argument("--end_at", type=int, default=10)
    ap.add_argument("--step_len", type=int, default=10)
    ap.add_argument("--valid_freq", type=int, default=1500)
    ap.add_argument("--easgd_alpha", type=float, default=0.5)
    args = ap.parse_args()

    server = ScheduledTrainingServer(
        port=args.port, learning_rate=args.learning_rate,
        start_halving_at=args.start_halving_at,
        end_at=args.end_at,
        step_len=args.step_len,
        valid_freq=args.valid_freq,
        easgd_alpha=args.easgd_alpha)
    server.serve()
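
Usage note: the report does not give the script's filename; assuming it is saved as, say, server.py (a hypothetical name), the controller would be started with a command along the lines of "python server.py --port 5567 --valid_freq 1500", after which workers connect to that port.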
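
The handle_control() docstring above describes a request/response protocol, which is easiest to see from the worker side. The following is a minimal sketch of a compatible worker loop, assuming a platoon-style Worker exposing a send_req method; the constructor argument and the train_on/validate callbacks are assumptions for illustration, not part of the file above.

# Hypothetical sketch of a worker loop for the protocol served by
# ScheduledTrainingServer.handle_control(). The Worker constructor
# argument and the train_on/validate callbacks are assumptions.
import time
from platoon.channel import Worker

def run_worker(worker, train_on, validate, num_batches):
    while True:
        resp = worker.send_req('next')
        if resp == 'stop':
            break                  # training finished
        elif resp == 'wait':
            time.sleep(1)          # another worker is evaluating
        elif resp == 'get_num_batches':
            # Tell the controller how many batches one epoch contains.
            worker.send_req({'get_num_batches_done': num_batches})
        elif 'sync_hyperparams' in resp:
            # Contains epoch, learning_rate and easgd_alpha; a real worker
            # would apply these before continuing.
            hyperparams = resp['sync_hyperparams']
        elif 'train' in resp:
            # resp['train'] is a list of batch indices from feed_batches().
            costs = train_on(resp['train'])
            worker.send_req({'train_done': None, 'costs': costs})
        elif 'valid' in resp:
            # Mid-epoch validation; costs are (name, value) pairs.
            worker.send_req({'valid_done': None, 'valid_costs': validate()})
        elif 'eval' in resp:
            # End-of-epoch evaluation; the controller then prepares a new epoch.
            worker.send_req({'eval_done': None, 'valid_costs': validate()})

Note how the branch order mirrors handle_control(): string responses are matched by equality, dict responses by key membership, and the 'train_done'/'valid_done'/'eval_done' payloads carry exactly the keys the controller reads back.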
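
A note on easgd_alpha: the controller only stores and serves this value (via 'get_easgd_alpha' and 'sync_hyperparams'); the name points at elastic averaging SGD (Zhang et al., 2015), whose update is presumably applied on the worker side. In that scheme, a worker with parameters x and the central copy x~ move toward each other on every synchronization:

    x  <- x  - alpha * (x - x~)
    x~ <- x~ + alpha * (x - x~)

so easgd_alpha controls how strongly each worker is pulled toward the shared center.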