Completed
Push — master ( 5a91c7...ef4013 )
by Raphael
01:39
created

ScheduledTrainingServer.log()   A

Complexity

Conditions 2

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
c 0
b 0
f 0
dl 0
loc 4
rs 10
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
from __future__ import print_function
5
import time
6
import sys, os
7
import numpy as np
8
import logging as loggers
9
from threading import Lock
10
# NOTE: the standard ``logging`` module is imported as ``loggers`` so that the
# name ``logging`` can refer to this module's own logger instance; the class
# below calls ``logging.info(...)`` on it.
logging = loggers.getLogger("ScheduledTrainingServer")
# Configure the root logger to emit INFO-level records and above.
loggers.basicConfig(level=loggers.INFO)
12
13
from platoon.channel import Controller
14
from argparse import ArgumentParser
15
16
17
# Default port for the controller; used as the default value of the ``port``
# argument of ``ScheduledTrainingServer.__init__``.
CONTROLLER_PORT = 5567
18
19
class ScheduledTrainingServer(Controller):
    """
    Multi-process controller that drives workers through a fixed epoch schedule.

    The server hands out shuffled batch indices to workers, halves the
    learning rate on every epoch from ``start_halving_at`` onwards, asks
    workers to validate every ``valid_freq`` fed batches, and stops training
    once the epoch counter exceeds ``end_at``.
    """

    def __init__(self, port=CONTROLLER_PORT, easgd_alpha=0.5,
                 # Following arguments can be received from workers
                 start_halving_at=6, end_at=10, step_len=10,
                 valid_freq=1500, learning_rate=0.1, log_path=None):
        """
        Initialize the controller.

        Args:
            port (int): port passed on to ``Controller.__init__``.
            easgd_alpha (float): value returned to workers on a
                ``get_easgd_alpha`` request and inside the hyperparameters.
            start_halving_at (int): first epoch at which the learning rate is
                halved; it is halved again on every later epoch.
            end_at (int): training stops once the epoch counter exceeds this.
            step_len (int): number of batch indices handed out per ``train``
                response.
            valid_freq (int): number of batches fed since the last validation
                after which a ``valid`` response is issued.
            learning_rate (float): initial learning rate.
            log_path (str or None): if given, log messages are also written to
                this file (opened in write mode, truncating existing content).

        The schedule arguments (everything but ``port``, ``easgd_alpha`` and
        ``log_path``) may be overridden by an ``init_schedule`` request
        received before the first epoch.
        """
        Controller.__init__(self, port)
        self.epoch_start_halving = start_halving_at
        self.end_at = end_at
        self.step_len = step_len
        self.start_time = None
        # Fixed seed so the batch order is reproducible between runs.
        self.rand = np.random.RandomState(3)
        self.epoch = 0
        self._current_iter = 0
        self._iters_from_last_valid = 0
        self._evaluating = False
        self._valid_freq = valid_freq
        self._done = False
        self._lr = learning_rate
        self._easgd_alpha = easgd_alpha
        self._training_names = []
        self._evaluation_names = []
        self._best_valid_cost = sys.float_info.max
        self._lock = Lock()

        self.num_train_batches = 0
        self.batch_pool = []
        self._train_costs = []
        self._epoch_start_time = None
        # Workers that already received the hyperparameters for this epoch.
        self.prepared_worker_pool = set()
        self.log_file = open(log_path, "w") if log_path else None
        if log_path:
            logging.info("write logs into {}".format(log_path))
        logging.info("multi-gpu server is listening port {}".format(port))

    def log(self, msg):
        """
        Emit ``msg`` through the module logger and, if a log file was opened,
        append it to that file as one line.
        """
        logging.info(msg)
        if self.log_file:
            self.log_file.write(msg + "\n")
            # The file is never explicitly closed, so flush after each
            # message to make sure it reaches disk even if the process dies.
            self.log_file.flush()

    def prepare_epoch(self):
        """
        Prepare for one epoch.

        Increments the epoch counter, halves the learning rate when the
        halving epoch has been reached, resets the per-epoch counters and
        refills the batch pool with a freshly shuffled list of batch indices.

        Returns:
            bool: False if to stop the training.
        """
        self.epoch += 1
        if self.epoch >= self.epoch_start_halving:
            self._lr *= 0.5
        self._current_iter = 0
        self._iters_from_last_valid = 0
        self._train_costs = []
        self.prepared_worker_pool.clear()
        # Materialize as a list: on Python 3 ``range`` is immutable and
        # cannot be shuffled in place.
        self.batch_pool = list(range(self.num_train_batches))
        self.rand.shuffle(self.batch_pool)
        if self.epoch > self.end_at:
            self.log("Training is done, wait all workers to stop")
            return False
        else:
            self.log("start epoch {} with lr={}".format(self.epoch, self._lr))
            return True

    def feed_batches(self):
        """
        Take the next ``step_len`` batch indices out of the pool.

        Returns:
            list or None: the batch indices, or None when the pool is empty.
        """
        if not self.batch_pool:
            return None
        else:
            batches = self.batch_pool[:self.step_len]
            self.batch_pool = self.batch_pool[self.step_len:]
            self._current_iter += len(batches)
            self._iters_from_last_valid += len(batches)
            return batches

    def feed_hyperparams(self):
        """
        Returns:
            dict: current ``epoch``, ``learning_rate`` and ``easgd_alpha``.
        """
        retval = {
            "epoch": self.epoch,
            "learning_rate": self._lr,
            "easgd_alpha": self._easgd_alpha
        }
        return retval

    def get_monitor_string(self, costs):
        """
        Format an iterable of ``(name, value)`` pairs as ``name=value`` items
        (two decimals) joined by spaces.
        """
        return " ".join(["{}={:.2f}".format(n, c) for (n, c) in costs])

    def _mark_if_best(self, valid_cost):
        """
        Record ``valid_cost`` if it is lower than the best one seen so far.

        Returns:
            str: "*" if a new best was recorded, "" otherwise.
        """
        if valid_cost < self._best_valid_cost:
            self._best_valid_cost = valid_cost
            return "*"
        return ""

    def handle_control(self, req, worker_id):
        """
        Handles a control_request received from a worker.

        Returns:
            string or dict: response

            'get_num_batches' - the number of training batches is not known
                                yet, the worker should report it
            'stop' - the worker should quit
            'wait' - an evaluation is in progress, the worker should wait
            'eval' - evaluate on valid and test set to start a new epoch
            'sync_hyperparams' - set learning rate
            'valid' - evaluate on valid and test set, then save the params
            'train' - train next batches
            ''      - nothing to do for requests that need no answer
        """
        if self.start_time is None:
            self.start_time = time.time()
        response = ""

        if req == 'next':
            if self.num_train_batches == 0:
                response = "get_num_batches"
            elif self._done:
                response = "stop"
                self.worker_is_done(worker_id)
            elif self._evaluating:
                response = 'wait'
            elif not self.batch_pool:
                # End of one iter
                if self._train_costs:
                    with self._lock:
                        # Return to the start of the progress line before
                        # the summary is logged.
                        sys.stdout.write("\r")
                        sys.stdout.flush()
                        # Average every monitored cost over the reported steps.
                        mean_costs = [
                            np.mean([c[i] for c in self._train_costs])
                            for i in range(len(self._training_names))
                        ]
                        self.log("train   (epoch={:2d}) {}".format(
                            self.epoch,
                            self.get_monitor_string(zip(self._training_names, mean_costs)))
                        )
                response = {'eval': None, 'best_valid_cost': self._best_valid_cost}
                self._evaluating = True
            else:
                # Continue training
                if worker_id not in self.prepared_worker_pool:
                    # First request of this worker in the epoch: push the
                    # current hyperparameters before any batch.
                    response = {"sync_hyperparams": self.feed_hyperparams()}
                    self.prepared_worker_pool.add(worker_id)
                elif self._iters_from_last_valid >= self._valid_freq:
                    response = {'valid': None, 'best_valid_cost': self._best_valid_cost}
                    self._iters_from_last_valid = 0
                else:
                    response = {"train": self.feed_batches()}
        elif 'eval_done' in req:
            with self._lock:
                self._evaluating = False
                sys.stdout.write("\r")
                sys.stdout.flush()
                if 'test_costs' in req and req['test_costs']:
                    self.log("test    (epoch={:2d}) {}".format(
                        self.epoch,
                        self.get_monitor_string(req['test_costs']))
                    )
                # Guard on the validation costs themselves (not the test
                # costs) so a request without test costs is still handled.
                if 'valid_costs' in req and req['valid_costs']:
                    # The first validation cost is the one used for tracking
                    # the best model.
                    valid_J = req['valid_costs'][0][1]
                    star_str = self._mark_if_best(valid_J)
                    # Log in both cases; a new best is marked with "*".
                    self.log("valid   (epoch={:2d}) {} {} (worker {})".format(
                        self.epoch,
                        self.get_monitor_string(req['valid_costs']),
                        star_str,
                        worker_id))
                continue_training = self.prepare_epoch()
                self._epoch_start_time = time.time()
                if not continue_training:
                    self._done = True
                    self.log("training time {:.4f}s".format(time.time() - self.start_time))
                    response = "stop"
        elif 'valid_done' in req:
            with self._lock:
                sys.stdout.write("\r")
                sys.stdout.flush()
                if 'valid_costs' in req and req['valid_costs']:
                    valid_J = req['valid_costs'][0][1]
                    star_str = self._mark_if_best(valid_J)
                    self.log("valid   ( dryrun ) {} {} (worker {})".format(
                        self.get_monitor_string(req['valid_costs']),
                        star_str,
                        worker_id
                    ))
        elif 'train_done' in req:
            costs = req['costs']
            self._train_costs.append(costs)
            # In-place progress line: clear the line, then print percentage
            # of the epoch, first cost and throughput.
            sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f | %.1f batch/s" % (
                self._current_iter * 100 / self.num_train_batches,
                costs[0], float(len(self._train_costs)*self.step_len)/(time.time() - self._epoch_start_time)))
            sys.stdout.flush()
        elif 'get_num_batches_done' in req:
            self.num_train_batches = req['get_num_batches_done']
        elif 'get_easgd_alpha' in req:
            response = self._easgd_alpha
        elif 'sync_hyperparams' in req:
            response = {"sync_hyperparams": self.feed_hyperparams()}
        elif 'init_schedule' in req:
            with self._lock:
                sys.stdout.write("\r")
                sys.stdout.flush()
                self.log("worker {} connected".format(worker_id))
                # The schedule can only be overridden before the first epoch.
                if self.epoch == 0:
                    schedule_params = req['init_schedule']
                    sch_str = " ".join("{}={}".format(a, b) for (a, b) in schedule_params.items())
                    self.log("initialize the schedule with {}".format(sch_str))
                    for key, val in schedule_params.items():
                        # Falsy values (None, 0, ...) keep the current setting.
                        if not val:
                            continue
                        if key == 'learning_rate':
                            self._lr = val
                        elif key == 'start_halving_at':
                            self.epoch_start_halving = val
                        elif key == 'end_at':
                            self.end_at = val
                        elif key == 'step_len':
                            self.step_len = val
                        elif key == 'valid_freq':
                            self._valid_freq = val

        elif 'set_names' in req:
            self._training_names = req['training_names']
            self._evaluation_names = req['evaluation_names']

        return response
259
260
if __name__ == '__main__':
    # Command-line entry point: parse the options, build the server and
    # start serving worker requests.
    ap = ArgumentParser()
    # Use the module constant so the CLI default and the class default
    # cannot drift apart.
    ap.add_argument("--port", type=int, default=CONTROLLER_PORT,
                    help="port the controller listens on")
    ap.add_argument("--easgd_alpha", type=float, default=0.5,
                    help="EASGD alpha value handed to the workers")
    ap.add_argument("--log", type=str, default=None,
                    help="optional path of a file to also write log messages to")
    args = ap.parse_args()

    server = ScheduledTrainingServer(
        port=args.port,
        easgd_alpha=args.easgd_alpha,
        log_path=args.log
    )
    server.serve()
273