ScheduledTrainingServer (rating: C)

Complexity

Total Complexity 56

Size/Duplication

Total Lines 244
Duplicated Lines 0 %

Importance

Changes 4
Bugs 0 Features 0
Metric                        Value
c    (changes)                4
b    (bugs)                   0
f    (features)               0
dl   (duplicated lines)       0
loc  (lines of code)          244
rs                            5.5555
wmc  (weighted method count)  56

7 Methods

Rating   Name                   Duplication   Size   Complexity
B        __init__()             0             41     3
A        get_monitor_string()   0             2      2
A        feed_hyperparams()     0             7      1
A        feed_batches()         0             9      2
F        handle_control()       0             147    42
A        prepare_epoch()        0             21     4
A        log()                  0             4      2

How to fix: Complexity

Complex Class

Complex classes like ScheduledTrainingServer often do a lot of different things. To break such a class down, we need to identify a cohesive component within the class. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
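As a concrete illustration on this class: _valid_freq, _iters_from_last_valid, and _best_valid_cost share the "valid" stem and all revolve around the validation schedule, so they are a natural candidate. Below is a minimal Extract Class sketch; the ValidationTracker name and its methods are hypothetical, not part of the original code:

import sys

class ValidationTracker(object):
    """Hypothetical extracted component: owns the validation schedule
    and best-cost bookkeeping of ScheduledTrainingServer."""

    def __init__(self, valid_freq):
        self.valid_freq = valid_freq           # was _valid_freq
        self.iters_since_last_valid = 0        # was _iters_from_last_valid
        self.best_cost = sys.float_info.max    # was _best_valid_cost

    def record_iters(self, n):
        # Called where feed_batches() currently bumps _iters_from_last_valid.
        self.iters_since_last_valid += n

    def due(self):
        # Replaces the inline `_iters_from_last_valid >= _valid_freq` check.
        if self.iters_since_last_valid >= self.valid_freq:
            self.iters_since_last_valid = 0
            return True
        return False

    def update(self, cost):
        # Replaces the duplicated best-cost blocks in the 'eval_done'
        # and 'valid_done' branches; True means a new best was found.
        if cost < self.best_cost:
            self.best_cost = cost
            return True
        return False

With this component in place, each duplicated block in handle_control() would reduce to star_str = "*" if self.validation.update(valid_J) else "", shaving several of its 42 decision points.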

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function
import sys
import time
import numpy as np
import logging as loggers
from threading import Lock

logging = loggers.getLogger("ScheduledTrainingServer")
loggers.basicConfig(level=loggers.INFO)

from platoon.channel import Controller
from argparse import ArgumentParser


CONTROLLER_PORT = 5567


class ScheduledTrainingServer(Controller):
    """
    This multi-process controller implements patience-based early-stopping SGD.
    """

    def __init__(self, port=CONTROLLER_PORT, easgd_alpha=0.5,
                 # The following arguments can be received from workers
                 start_halving_at=6, end_at=10, sync_freq=10, halving_freq=1,
                 valid_freq=1500, learning_rate=0.1, log_path=None):
        """
        Initialize the controller.

        Args:
            port (int): port the controller listens on
            easgd_alpha (float): moving-average coefficient for EASGD updates
        """
        Controller.__init__(self, port)
        self.epoch_start_halving = start_halving_at
        self.end_at = end_at
        self.sync_freq = sync_freq
        self.start_time = None
        self.rand = np.random.RandomState(3)
        self.epoch = 0
        self._current_iter = 0
        self._iters_from_last_valid = 0
        self._evaluating = False
        self._valid_freq = valid_freq
        self._halving_freq = halving_freq
        self._done = False
        self._lr = learning_rate
        self._easgd_alpha = easgd_alpha
        self._training_names = []
        self._evaluation_names = []
        self._best_valid_cost = sys.float_info.max
        self._lock = Lock()

        self.num_train_batches = 0
        self.batch_pool = []
        self._train_costs = []
        self._epoch_start_time = None
        self.prepared_worker_pool = set()
        self.log_file = open(log_path, "w") if log_path else None
        if log_path:
            logging.info("write logs into {}".format(log_path))
        logging.info("multi-gpu server is listening on port {}".format(port))

    def log(self, msg):
        logging.info(msg)
        if self.log_file:
            self.log_file.write(msg + "\n")

    def prepare_epoch(self):
        """
        Prepare for one epoch.

        Returns:
            bool: False if the training should stop.
        """
        self.epoch += 1
        if self.epoch >= self.epoch_start_halving and ((self.epoch - self.epoch_start_halving) % self._halving_freq == 0):
            self._lr *= 0.5
        self._current_iter = 0
        self._iters_from_last_valid = 0
        self._train_costs = []
        self.prepared_worker_pool.clear()
        # list() so the pool can be shuffled and sliced under Python 3,
        # where range() returns a lazy sequence
        self.batch_pool = list(range(self.num_train_batches))
        self.rand.shuffle(self.batch_pool)
        if self.epoch > self.end_at:
            self.log("Training is done, waiting for all workers to stop")
            return False
        else:
            self.log("start epoch {} with lr={}".format(self.epoch, self._lr))
            return True

    def feed_batches(self):
        """Hand the next `sync_freq` batch indices to a worker."""
        if not self.batch_pool:
            return None
        else:
            batches = self.batch_pool[:self.sync_freq]
            self.batch_pool = self.batch_pool[self.sync_freq:]
            self._current_iter += len(batches)
            self._iters_from_last_valid += len(batches)
            return batches

    def feed_hyperparams(self):
        retval = {
            "epoch": self.epoch,
            "learning_rate": self._lr,
            "easgd_alpha": self._easgd_alpha
        }
        return retval

    def get_monitor_string(self, costs):
        return " ".join(["{}={:.2f}".format(n, c) for (n, c) in costs])

    def handle_control(self, req, worker_id, req_info):
        """
        Handle a control request received from a worker.

        Returns:
            string or dict: response

            'stop' - the worker should quit
            'wait' - wait for 1 second
            'eval' - evaluate on valid and test set to start a new epoch
            'sync_hyperparams' - set learning rate
            'valid' - evaluate on valid and test set, then save the params
            'train' - train next batches
        """
        if self.start_time is None:
            self.start_time = time.time()
        response = ""

        if req == 'next':
            if self.num_train_batches == 0:
                response = "get_num_batches"
            elif self._done:
                response = "stop"
                self.worker_is_done(worker_id)
            elif self._evaluating:
                response = 'wait'
            elif not self.batch_pool:
                # End of one epoch: report mean training costs, then evaluate
                if self._train_costs:
                    with self._lock:
                        sys.stdout.write("\r")
                        sys.stdout.flush()
                        mean_costs = []
                        for i in range(len(self._training_names)):
                            mean_costs.append(np.mean([c[i] for c in self._train_costs]))
                        self.log("train   (epoch={:2d}) {}".format(
                            self.epoch,
                            self.get_monitor_string(zip(self._training_names, mean_costs)))
                        )
                response = {'eval': None, 'best_valid_cost': self._best_valid_cost}
                self._evaluating = True
            else:
                # Continue training
                if worker_id not in self.prepared_worker_pool:
                    response = {"sync_hyperparams": self.feed_hyperparams()}
                    self.prepared_worker_pool.add(worker_id)
                elif self._iters_from_last_valid >= self._valid_freq:
                    response = {'valid': None, 'best_valid_cost': self._best_valid_cost}
                    self._iters_from_last_valid = 0
                else:
                    response = {"train": self.feed_batches()}
        elif 'eval_done' in req:
            with self._lock:
                self._evaluating = False
                sys.stdout.write("\r")
                sys.stdout.flush()
                if 'test_costs' in req and req['test_costs']:
                    self.log("test    (epoch={:2d}) {} (worker {})".format(
                        self.epoch,
                        self.get_monitor_string(req['test_costs']),
                        worker_id)
                    )
                if 'valid_costs' in req and req['valid_costs']:
                    valid_J = req['valid_costs'][0][1]
                    if valid_J < self._best_valid_cost:
                        self._best_valid_cost = valid_J
                        star_str = "*"
                    else:
                        star_str = ""
                    self.log("valid   (epoch={:2d}) {} {} (worker {})".format(
                        self.epoch,
                        self.get_monitor_string(req['valid_costs']),
                        star_str,
                        worker_id))
                    # if star_str and 'auto_save' in req and req['auto_save']:
                    #     self.log("(worker {}) save the model to {}".format(
                    #         worker_id,
                    #         req['auto_save']
                    #     ))
                continue_training = self.prepare_epoch()
                self._epoch_start_time = time.time()
                if not continue_training:
                    self._done = True
                    self.log("training time {:.4f}s".format(time.time() - self.start_time))
                    response = "stop"
        elif 'valid_done' in req:
            with self._lock:
                sys.stdout.write("\r")
                sys.stdout.flush()
                if 'valid_costs' in req:
                    valid_J = req['valid_costs'][0][1]
                    if valid_J < self._best_valid_cost:
                        self._best_valid_cost = valid_J
                        star_str = "*"
                    else:
                        star_str = ""
                    self.log("valid   ( dryrun ) {} {} (worker {})".format(
                        self.get_monitor_string(req['valid_costs']),
                        star_str,
                        worker_id
                    ))
                    # if star_str and 'auto_save' in req and req['auto_save']:
                    #     self.log("(worker {}) save the model to {}".format(
                    #         worker_id,
                    #         req['auto_save']
                    #     ))
        elif 'train_done' in req:
            costs = req['costs']
            self._train_costs.append(costs)
            # Progress line: percentage of the epoch, first cost, throughput
            sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f | %.1f batch/s" % (
                self._current_iter * 100 / self.num_train_batches,
                costs[0], float(len(self._train_costs) * self.sync_freq) / (time.time() - self._epoch_start_time)))
            sys.stdout.flush()
        elif 'get_num_batches_done' in req:
            self.num_train_batches = req['get_num_batches_done']
        elif 'get_easgd_alpha' in req:
            response = self._easgd_alpha
        elif 'sync_hyperparams' in req:
            response = {"sync_hyperparams": self.feed_hyperparams()}
        elif 'init_schedule' in req:
            with self._lock:
                sys.stdout.write("\r")
                sys.stdout.flush()
                self.log("worker {} connected".format(worker_id))
                if self.epoch == 0:
                    schedule_params = req['init_schedule']
                    sch_str = " ".join("{}={}".format(a, b) for (a, b) in schedule_params.items())
                    self.log("initialize the schedule with {}".format(sch_str))
                    for key, val in schedule_params.items():
                        if not val:
                            continue
                        if key == 'learning_rate':
                            self._lr = val
                        elif key == 'start_halving_at':
                            self.epoch_start_halving = val
                        elif key == 'halving_freq':
                            self._halving_freq = val
                        elif key == 'end_at':
                            self.end_at = val
                        elif key == 'sync_freq':
                            self.sync_freq = val
                        elif key == 'valid_freq':
                            self._valid_freq = val
        elif 'set_names' in req:
            self._training_names = req['training_names']
            self._evaluation_names = req['evaluation_names']

        return response


if __name__ == '__main__':
    ap = ArgumentParser()
    ap.add_argument("--port", type=int, default=5567)
    ap.add_argument("--easgd_alpha", type=float, default=0.5)
    ap.add_argument("--log", type=str, default=None)
    args = ap.parse_args()

    server = ScheduledTrainingServer(
        port=args.port,
        easgd_alpha=args.easgd_alpha,
        log_path=args.log
    )

    server.serve()
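For reference, the request/response cycle documented in handle_control() can be exercised by calling the method directly, without going through platoon's ZeroMQ transport. The walkthrough below is a hypothetical illustration, not part of the original file; it assumes platoon is importable, the port is free, and the cost values are made up:

# Hypothetical driver: walk one worker through the control protocol.
server = ScheduledTrainingServer(port=5567)

# 1. The controller first asks for the number of training batches.
assert server.handle_control('next', 0, None) == 'get_num_batches'
server.handle_control({'get_num_batches_done': 100}, 0, None)

# 2. The batch pool starts empty, so the worker is sent to evaluate;
#    'eval_done' then runs prepare_epoch(), which refills the pool.
assert 'eval' in server.handle_control('next', 0, None)
server.handle_control({'eval_done': None, 'valid_costs': [('J', 1.0)]}, 0, None)

# 3. An unprepared worker first receives the hyperparameters ...
assert 'sync_hyperparams' in server.handle_control('next', 0, None)

# 4. ... and only then batch indices, sync_freq at a time.
batches = server.handle_control('next', 0, None)['train']
assert len(batches) == server.sync_freq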