Completed
Push — master (5a91c7...ef4013), created by Raphael at 01:39

deepy/multigpu/server.py (2 issues)

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function
import time
import sys
import numpy as np
import logging as loggers
from threading import Lock

logging = loggers.getLogger("ScheduledTrainingServer")
loggers.basicConfig(level=loggers.INFO)

from platoon.channel import Controller
from argparse import ArgumentParser


CONTROLLER_PORT = 5567

class ScheduledTrainingServer(Controller):
    """
    This multi-process controller implements epoch-scheduled EASGD training:
    it dispatches batch indices to GPU workers, halves the learning rate
    from `start_halving_at` onwards, and stops after `end_at` epochs.
    """

    def __init__(self, port=CONTROLLER_PORT, easgd_alpha=0.5,
                 # The following arguments can be overridden by workers via 'init_schedule'
                 start_halving_at=6, end_at=10, step_len=10,
                 valid_freq=1500,
                 learning_rate=0.1):
        """
        Initialize the controller.

        Args:
            port (int): port to listen on
            easgd_alpha (float): alpha coefficient of the EASGD update rule
            start_halving_at (int): epoch from which the learning rate is halved
            end_at (int): epoch after which training stops
            step_len (int): batches in one training step
            valid_freq (int): number of batches between validations
            learning_rate (float): initial learning rate
        """

        Controller.__init__(self, port)
        self.epoch_start_halving = start_halving_at
        self.end_at = end_at
        self.step_len = step_len
        self.start_time = None
        self.rand = np.random.RandomState(3)
        self.epoch = 0
        self._current_iter = 0
        self._iters_from_last_valid = 0
        self._evaluating = False
        self._valid_freq = valid_freq
        self._done = False
        self._lr = learning_rate
        self._easgd_alpha = easgd_alpha
        self._training_names = []
        self._evaluation_names = []
        self._best_valid_cost = sys.float_info.max
        self._lock = Lock()

        self.num_train_batches = 0
        self.batch_pool = []
        self._train_costs = []
        self.prepared_worker_pool = set()
        logging.info("multi-gpu server is listening on port {}".format(port))

    def prepare_epoch(self):
        """
        Prepare for one epoch.

        Returns:
            bool: False if training should stop.
        """
        self.epoch += 1
        if self.epoch >= self.epoch_start_halving:
            self._lr *= 0.5
        self._current_iter = 0
        self._iters_from_last_valid = 0
        self._train_costs = []
        self.prepared_worker_pool.clear()
        # list() so the pool is shuffleable and sliceable under Python 3
        self.batch_pool = list(range(self.num_train_batches))
        self.rand.shuffle(self.batch_pool)
        if self.epoch > self.end_at:
            logging.info("Training is done; waiting for all workers to stop")
            return False
        else:
            logging.info("start epoch {} with lr={}".format(self.epoch, self._lr))
            return True
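
    # A worked example of the schedule above, using the defaults
    # (start_halving_at=6, end_at=10, learning_rate=0.1): epochs 1-5 train
    # at lr=0.1; from epoch 6 the rate halves each epoch (0.05, 0.025, ...,
    # 0.003125 at epoch 10), and the call for epoch 11 returns False.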

    def feed_batches(self):
        """Pop the next `step_len` batch indices from the shuffled pool."""
        if not self.batch_pool:
            return None
        else:
            batches = self.batch_pool[:self.step_len]
            self.batch_pool = self.batch_pool[self.step_len:]
            self._current_iter += len(batches)
            self._iters_from_last_valid += len(batches)
            return batches
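
    # For example, with step_len=10 and 25 batches in the pool, successive
    # calls return 10, 10, and 5 batch indices, then None once the pool is
    # exhausted.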

    def feed_hyperparams(self):
        """Report the current epoch, learning rate, and EASGD alpha to a worker."""
        retval = {
            "epoch": self.epoch,
            "learning_rate": self._lr,
            "easgd_alpha": self._easgd_alpha
        }
        return retval
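
    # A response seen by a worker might look like (values illustrative):
    # {"epoch": 3, "learning_rate": 0.05, "easgd_alpha": 0.5}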

    def get_monitor_string(self, costs):
        """Format (name, cost) pairs for logging, e.g. [("J", 0.4167)] -> "J=0.42"."""
        return " ".join(["{}={:.2f}".format(n, c) for (n, c) in costs])

    def _report_valid_costs(self, req, worker_id, prefix):
        """
        Log validation costs with the given prefix, remember the best cost
        seen so far (marked with '*'), and report when the reporting worker
        auto-saves an improved model.
        """
        valid_J = req['valid_costs'][0][1]
        if valid_J < self._best_valid_cost:
            self._best_valid_cost = valid_J
            star_str = "*"
        else:
            star_str = ""
        logging.info("{} {} {}".format(
            prefix,
            self.get_monitor_string(req['valid_costs']),
            star_str))
        if star_str and 'auto_save' in req and req['auto_save']:
            logging.info("(worker {}) save the model to {}".format(
                worker_id,
                req['auto_save']
            ))

    def handle_control(self, req, worker_id):
        """
        Handle a control request received from a worker.

        Returns:
            str or dict: response

            'stop' - the worker should quit
            'wait' - wait for 1 second
            'eval' - evaluate on the valid and test sets, then start a new epoch
            'sync_hyperparams' - update the learning rate and other hyperparameters
            'valid' - evaluate on the valid and test sets, then save the params
            'train' - train on the next batches
        """
        if self.start_time is None:
            self.start_time = time.time()
        response = ""

        if req == 'next':
            if self.num_train_batches == 0:
                response = "get_num_batches"
            elif self._done:
                response = "stop"
                self.worker_is_done(worker_id)
            elif self._evaluating:
                response = 'wait'
            elif not self.batch_pool:
                # End of one epoch
                if self._train_costs:
                    with self._lock:
                        sys.stdout.write("\r")
                        sys.stdout.flush()
                        mean_costs = []
                        for i in range(len(self._training_names)):
                            mean_costs.append(np.mean([c[i] for c in self._train_costs]))
                        logging.info("train   (epoch={:2d}) {}".format(
                            self.epoch,
                            self.get_monitor_string(zip(self._training_names, mean_costs)))
                        )
                response = {'eval': None, 'best_valid_cost': self._best_valid_cost}
                self._evaluating = True
            else:
                # Continue training
                if worker_id not in self.prepared_worker_pool:
                    response = {"sync_hyperparams": self.feed_hyperparams()}
                    self.prepared_worker_pool.add(worker_id)
                elif self._iters_from_last_valid >= self._valid_freq:
                    response = {'valid': None, 'best_valid_cost': self._best_valid_cost}
                    self._iters_from_last_valid = 0
                else:
                    response = {"train": self.feed_batches()}
        elif 'eval_done' in req:
            with self._lock:
                self._evaluating = False
                sys.stdout.write("\r")
                sys.stdout.flush()
                if 'test_costs' in req and req['test_costs']:
                    logging.info("test    (epoch={:2d}) {}".format(
                        self.epoch,
                        self.get_monitor_string(req['test_costs']))
                    )
                if 'valid_costs' in req and req['valid_costs']:
                    self._report_valid_costs(
                        req, worker_id,
                        "valid   (epoch={:2d})".format(self.epoch))
                continue_training = self.prepare_epoch()
                if not continue_training:
                    self._done = True
                    logging.info("training time {:.4f}s".format(time.time() - self.start_time))
                    response = "stop"
        elif 'valid_done' in req:
            with self._lock:
                sys.stdout.write("\r")
                sys.stdout.flush()
                if 'valid_costs' in req:
                    self._report_valid_costs(req, worker_id, "valid   ( dryrun )")
        elif 'train_done' in req:
            costs = req['costs']
            self._train_costs.append(costs)
            # "\x1b[2K" erases the line so the progress bar redraws in place
            sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f" % (
                self._current_iter * 100 / self.num_train_batches,
                costs[0]))
            sys.stdout.flush()
        elif 'get_num_batches_done' in req:
            self.num_train_batches = req['get_num_batches_done']
        elif 'get_easgd_alpha' in req:
            response = self._easgd_alpha
        elif 'sync_hyperparams' in req:
            response = {"sync_hyperparams": self.feed_hyperparams()}
        elif 'init_schedule' in req:
            with self._lock:
                sys.stdout.write("\r")
                sys.stdout.flush()
                logging.info("worker {} connected".format(worker_id))
                if self.epoch == 0:
                    schedule_params = req['init_schedule']
                    sch_str = " ".join("{}={}".format(a, b) for (a, b) in schedule_params.items())
                    logging.info("initialize the schedule with {}".format(sch_str))
                    for key, val in schedule_params.items():
                        if not val:
                            continue
                        if key == 'learning_rate':
                            self._lr = val
                        elif key == 'start_halving_at':
                            self.epoch_start_halving = val
                        elif key == 'end_at':
                            self.end_at = val
                        elif key == 'step_len':
                            self.step_len = val
                        elif key == 'valid_freq':
                            self._valid_freq = val

        elif 'set_names' in req:
            self._training_names = req['training_names']
            self._evaluation_names = req['evaluation_names']

        return response


if __name__ == '__main__':
    ap = ArgumentParser()
    ap.add_argument("--port", type=int, default=CONTROLLER_PORT)
    ap.add_argument("--easgd_alpha", type=float, default=0.5)
    args = ap.parse_args()

    server = ScheduledTrainingServer(
        port=args.port,
        easgd_alpha=args.easgd_alpha)
    server.serve()
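
# ----------------------------------------------------------------------
# A minimal sketch of the worker side of this protocol, for illustration
# only. It assumes a platoon-style worker object whose send_req() posts a
# request to this controller and returns its response; `train_fn`,
# `num_batches`, and `apply_hyperparams` are hypothetical stand-ins for
# the worker's own training function, dataset size, and update logic.
#
#     worker.send_req({'init_schedule': {'learning_rate': 0.1, 'end_at': 10}})
#     while True:
#         resp = worker.send_req('next')
#         if resp == 'stop':
#             break
#         elif resp == 'wait':
#             time.sleep(1)
#         elif resp == 'get_num_batches':
#             worker.send_req({'get_num_batches_done': num_batches})
#         elif 'sync_hyperparams' in resp:
#             apply_hyperparams(resp['sync_hyperparams'])  # hypothetical
#         elif 'train' in resp:
#             costs = train_fn(resp['train'])              # hypothetical
#             worker.send_req({'train_done': None, 'costs': costs})
# ----------------------------------------------------------------------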