#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function
import time
import sys, os
import numpy as np
import logging as loggers
from threading import Lock
logging = loggers.getLogger("ScheduledTrainingServer")
loggers.basicConfig(level=loggers.INFO)

from platoon.channel import Controller
from argparse import ArgumentParser


CONTROLLER_PORT = 5567

class ScheduledTrainingServer(Controller):
    """
    A multi-process controller that schedules SGD over EASGD workers:
    the learning rate is halved on every epoch once ``start_halving_at``
    is reached, and training stops after ``end_at`` epochs.
    """

    def __init__(self, port=CONTROLLER_PORT, easgd_alpha=0.5,
                 # The following arguments can also be set by workers
                 # through an 'init_schedule' request
                 start_halving_at=6, end_at=10, step_len=10,
                 valid_freq=1500,
                 learning_rate=0.1):
        """
        Initialize the controller.

        Args:
            port (int): port the controller listens on
            easgd_alpha (float): elastic averaging coefficient for EASGD
            start_halving_at (int): epoch from which the learning rate is halved
            end_at (int): epoch after which training stops
            step_len (int): number of batches in one training step
            valid_freq (int): number of batches between validation runs
            learning_rate (float): initial learning rate
        """

        Controller.__init__(self, port)
        self.epoch_start_halving = start_halving_at
        self.end_at = end_at
        self.step_len = step_len
        self.start_time = None
        self.rand = np.random.RandomState(3)
        self.epoch = 0
        self._current_iter = 0
        self._iters_from_last_valid = 0
        self._evaluating = False
        self._valid_freq = valid_freq
        self._done = False
        self._lr = learning_rate
        self._easgd_alpha = easgd_alpha
        self._training_names = []
        self._evaluation_names = []
        self._best_valid_cost = sys.float_info.max
        self._lock = Lock()

        self.num_train_batches = 0
        self.batch_pool = []
        self._train_costs = []
        self.prepared_worker_pool = set()
        logging.info("multi-gpu server is listening on port {}".format(port))

    def prepare_epoch(self):
        """
        Prepare for one epoch.

        Returns:
            bool: False if training should stop.
        """
        self.epoch += 1
        if self.epoch >= self.epoch_start_halving:
            self._lr *= 0.5
        self._current_iter = 0
        self._iters_from_last_valid = 0
        self._train_costs = []
        self.prepared_worker_pool.clear()
        # list() keeps the pool shuffleable under Python 3, where range()
        # returns an immutable sequence
        self.batch_pool = list(range(self.num_train_batches))
        self.rand.shuffle(self.batch_pool)
        if self.epoch > self.end_at:
            logging.info("Training is done, waiting for all workers to stop")
            return False
        else:
            logging.info("start epoch {} with lr={}".format(self.epoch, self._lr))
            return True

    def feed_batches(self):
        # Hand out the next `step_len` batch indices, or None once the
        # epoch's pool is exhausted
        if not self.batch_pool:
            return None
        else:
            batches = self.batch_pool[:self.step_len]
            self.batch_pool = self.batch_pool[self.step_len:]
            self._current_iter += len(batches)
            self._iters_from_last_valid += len(batches)
            return batches

    def feed_hyperparams(self):
        retval = {
            "epoch": self.epoch,
            "learning_rate": self._lr,
            "easgd_alpha": self._easgd_alpha
        }
        return retval

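    # Workers that fetch ``easgd_alpha`` are expected to apply the elastic
    # averaging update of EASGD (Zhang et al., 2015). A sketch, where ``x``
    # is a worker's local parameter vector and ``center`` the shared central
    # copy:
    #
    #     diff    = alpha * (x - center)
    #     x      -= diff
    #     center += diff
    #
    # This server only distributes alpha; the update itself lives in the
    # workers.
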
    def get_monitor_string(self, costs):
        return " ".join(["{}={:.2f}".format(n, c) for (n, c) in costs])

    def handle_control(self, req, worker_id):
        """
        Handle a control request received from a worker.

        Returns:
            str or dict: response

            'stop' - the worker should quit
            'wait' - wait for 1 second
            'eval' - evaluate on the valid and test sets, then start a new epoch
            'sync_hyperparams' - update hyperparameters such as the learning rate
            'valid' - evaluate on the valid and test sets, then save the parameters
            'train' - train on the next batches

        A worker-side sketch of the loop consuming these responses is given
        above the __main__ block at the bottom of this file.
        """
        if self.start_time is None:
            self.start_time = time.time()
        response = ""

        if req == 'next':
            if self.num_train_batches == 0:
                response = "get_num_batches"
            elif self._done:
                response = "stop"
                self.worker_is_done(worker_id)
            elif self._evaluating:
                response = 'wait'
            elif not self.batch_pool:
                # End of one epoch: report the mean training costs, then
                # ask this worker to run a full evaluation
                if self._train_costs:
                    with self._lock:
                        sys.stdout.write("\r")
                        sys.stdout.flush()
                        mean_costs = []
                        for i in range(len(self._training_names)):
                            mean_costs.append(np.mean([c[i] for c in self._train_costs]))
                        logging.info("train (epoch={:2d}) {}".format(
                            self.epoch,
                            self.get_monitor_string(zip(self._training_names, mean_costs)))
                        )
                response = {'eval': None, 'best_valid_cost': self._best_valid_cost}
                self._evaluating = True
            else:
                # Continue training
                if worker_id not in self.prepared_worker_pool:
                    response = {"sync_hyperparams": self.feed_hyperparams()}
                    self.prepared_worker_pool.add(worker_id)
                elif self._iters_from_last_valid >= self._valid_freq:
                    response = {'valid': None, 'best_valid_cost': self._best_valid_cost}
                    self._iters_from_last_valid = 0
                else:
                    response = {"train": self.feed_batches()}
        elif 'eval_done' in req:
            with self._lock:
                self._evaluating = False
                sys.stdout.write("\r")
                sys.stdout.flush()
                if 'test_costs' in req and req['test_costs']:
                    logging.info("test (epoch={:2d}) {}".format(
                        self.epoch,
                        self.get_monitor_string(req['test_costs']))
                    )
                if 'valid_costs' in req and req['valid_costs']:
                    valid_J = req['valid_costs'][0][1]
                    if valid_J < self._best_valid_cost:
                        self._best_valid_cost = valid_J
                        star_str = "*"
                    else:
                        star_str = ""
                    logging.info("valid (epoch={:2d}) {} {}".format(
                        self.epoch,
                        self.get_monitor_string(req['valid_costs']),
                        star_str))
                    if star_str and 'auto_save' in req and req['auto_save']:
                        logging.info("(worker {}) saved the model to {}".format(
                            worker_id,
                            req['auto_save']
                        ))
                continue_training = self.prepare_epoch()
                if not continue_training:
                    self._done = True
                    logging.info("training time {:.4f}s".format(time.time() - self.start_time))
                    response = "stop"
        elif 'valid_done' in req:
            with self._lock:
                sys.stdout.write("\r")
                sys.stdout.flush()
                if 'valid_costs' in req:
                    valid_J = req['valid_costs'][0][1]
                    if valid_J < self._best_valid_cost:
                        self._best_valid_cost = valid_J
                        star_str = "*"
                    else:
                        star_str = ""
                    logging.info("valid ( dryrun ) {} {}".format(
                        self.get_monitor_string(req['valid_costs']),
                        star_str
                    ))
                    if star_str and 'auto_save' in req and req['auto_save']:
                        logging.info("(worker {}) saved the model to {}".format(
                            worker_id,
                            req['auto_save']
                        ))
        elif 'train_done' in req:
            costs = req['costs']
            self._train_costs.append(costs)
            sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f" % (self._current_iter * 100 / self.num_train_batches,
                                                           costs[0]))
            sys.stdout.flush()
        elif 'get_num_batches_done' in req:
            self.num_train_batches = req['get_num_batches_done']
        elif 'get_easgd_alpha' in req:
            response = self._easgd_alpha
        elif 'sync_hyperparams' in req:
            response = {"sync_hyperparams": self.feed_hyperparams()}
        elif 'init_schedule' in req:
            with self._lock:
                sys.stdout.write("\r")
                sys.stdout.flush()
                logging.info("worker {} connected".format(worker_id))
                if self.epoch == 0:
                    # Accept schedule parameters only before training starts
                    schedule_params = req['init_schedule']
                    sch_str = " ".join("{}={}".format(a, b) for (a, b) in schedule_params.items())
                    logging.info("initialize the schedule with {}".format(sch_str))
                    for key, val in schedule_params.items():
                        if not val:
                            continue
                        if key == 'learning_rate':
                            self._lr = val
                        elif key == 'start_halving_at':
                            self.epoch_start_halving = val
                        elif key == 'end_at':
                            self.end_at = val
                        elif key == 'step_len':
                            self.step_len = val
                        elif key == 'valid_freq':
                            self._valid_freq = val

        elif 'set_names' in req:
            self._training_names = req['training_names']
            self._evaluation_names = req['evaluation_names']

        return response

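# A minimal worker-side sketch of the control loop this server expects.
# Illustration only: it assumes platoon.channel.Worker accepts a control_port
# argument and exposes send_req(); train_batches(), validate(),
# evaluate_test() and apply_hyperparams() are hypothetical worker functions.
#
#     from platoon.channel import Worker
#
#     worker = Worker(control_port=CONTROLLER_PORT)
#     while True:
#         resp = worker.send_req('next')
#         if resp == 'stop':
#             break
#         elif resp == 'wait':
#             time.sleep(1)
#         elif resp == 'get_num_batches':
#             worker.send_req({'get_num_batches_done': num_train_batches})
#         elif 'sync_hyperparams' in resp:
#             apply_hyperparams(resp['sync_hyperparams'])
#         elif 'train' in resp:
#             costs = train_batches(resp['train'])
#             worker.send_req({'train_done': None, 'costs': costs})
#         elif 'valid' in resp:
#             worker.send_req({'valid_done': None, 'valid_costs': validate()})
#         elif 'eval' in resp:
#             worker.send_req({'eval_done': None, 'valid_costs': validate(),
#                              'test_costs': evaluate_test()})
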
if __name__ == '__main__':
    ap = ArgumentParser()
    ap.add_argument("--port", type=int, default=5567)
    ap.add_argument("--easgd_alpha", type=float, default=0.5)
    args = ap.parse_args()

    server = ScheduledTrainingServer(
        port=args.port,
        easgd_alpha=args.easgd_alpha)
    server.serve()
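
# Example launch (assuming this file is saved as scheduled_server.py; workers
# then connect to the same port):
#
#     python scheduled_server.py --port 5567 --easgd_alpha 0.5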