Conditions | 42 |
Total Lines | 147 |
Lines | 0 |
Ratio | 0 % |
Changes | 4 | ||
Bugs | 0 | Features | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
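Commonly applied refactorings include Extract Method.
If many parameters/temporary variables are present, Replace Temp with Query, Introduce Parameter Object, Preserve Whole Object, or Replace Method with Method Object can help.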
Complex classes like ScheduledTrainingServer, whose handle_control() method is flagged here, often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
1 | #!/usr/bin/env python |
||
def handle_control(self, req, worker_id, req_info):
    """
    Handles a control_request received from a worker.

    Args:
        req: either the plain string 'next', or a mapping whose keys select
            the action ('eval_done', 'valid_done', 'train_done',
            'get_num_batches_done', 'get_easgd_alpha', 'sync_hyperparams',
            'init_schedule', 'set_names') and carry its payload.
        worker_id: identifier of the requesting worker; used in log lines
            and to track which workers already received hyperparameters.
        req_info: not used by this method.

    Returns:
        string or dict: response ("" when the request needs no reply)

        'stop' - the worker should quit
        'wait' - wait for 1 second
        'get_num_batches' - the worker should report the number of batches
        'eval' - evaluate on valid and test set to start a new epoch
        'sync_hyperparams' - set learning rate
        'valid' - evaluate on valid and test set, then save the params
        'train' - train next batches
        For 'get_easgd_alpha' the raw ``self._easgd_alpha`` value is returned.
    """
    if self.start_time is None:
        self.start_time = time.time()
    response = ""

    def record_valid_cost(valid_costs):
        # Update the best validation cost and return "*" if it improved,
        # "" otherwise. valid_costs[0][1] is presumably the value of the
        # first (name, value) pair -- confirm against the worker side.
        valid_J = valid_costs[0][1]
        if valid_J < self._best_valid_cost:
            self._best_valid_cost = valid_J
            return "*"
        return ""

    # NOTE(review): the listing this was recovered from lost its
    # indentation; each `with self._lock:` is assumed to cover its whole
    # branch -- confirm against version control.
    if req == 'next':
        if self.num_train_batches == 0:
            # Dataset size still unknown; ask the worker to report it.
            response = "get_num_batches"
        elif self._done:
            response = "stop"
            self.worker_is_done(worker_id)
        elif self._evaluating:
            response = 'wait'
        elif not self.batch_pool:
            # End of one iter: report mean training costs, then ask this
            # worker to run the evaluation.
            if self._train_costs:
                with self._lock:
                    sys.stdout.write("\r")
                    sys.stdout.flush()
                    mean_costs = [
                        np.mean([c[i] for c in self._train_costs])
                        for i in range(len(self._training_names))
                    ]
                    self.log("train (epoch={:2d}) {}".format(
                        self.epoch,
                        self.get_monitor_string(zip(self._training_names, mean_costs)))
                    )
            response = {'eval': None, 'best_valid_cost': self._best_valid_cost}
            self._evaluating = True
        elif worker_id not in self.prepared_worker_pool:
            # Continue training; first push hyperparameters to this worker.
            response = {"sync_hyperparams": self.feed_hyperparams()}
            self.prepared_worker_pool.add(worker_id)
        elif self._iters_from_last_valid >= self._valid_freq:
            response = {'valid': None, 'best_valid_cost': self._best_valid_cost}
            self._iters_from_last_valid = 0
        else:
            response = {"train": self.feed_batches()}
    elif 'eval_done' in req:
        with self._lock:
            self._evaluating = False
            sys.stdout.write("\r")
            sys.stdout.flush()
            if 'test_costs' in req and req['test_costs']:
                self.log("test (epoch={:2d}) {} (worker {})".format(
                    self.epoch,
                    self.get_monitor_string(req['test_costs']),
                    worker_id)
                )
            # Previously guarded by req['test_costs'], which raised KeyError
            # when no test costs were sent and skipped validation tracking
            # whenever the test costs were empty.
            if 'valid_costs' in req and req['valid_costs']:
                star_str = record_valid_cost(req['valid_costs'])
                self.log("valid (epoch={:2d}) {} {} (worker {})".format(
                    self.epoch,
                    self.get_monitor_string(req['valid_costs']),
                    star_str,
                    worker_id))
            continue_training = self.prepare_epoch()
            self._epoch_start_time = time.time()
            if not continue_training:
                self._done = True
                self.log("training time {:.4f}s".format(time.time() - self.start_time))
                response = "stop"
    elif 'valid_done' in req:
        with self._lock:
            sys.stdout.write("\r")
            sys.stdout.flush()
            # Guard against an empty list, which used to raise IndexError.
            if 'valid_costs' in req and req['valid_costs']:
                star_str = record_valid_cost(req['valid_costs'])
                self.log("valid ( dryrun ) {} {} (worker {})".format(
                    self.get_monitor_string(req['valid_costs']),
                    star_str,
                    worker_id
                ))
    elif 'train_done' in req:
        costs = req['costs']
        self._train_costs.append(costs)
        # "\x1b[2K\r" erases the terminal line so the progress line is
        # redrawn in place.
        elapsed = time.time() - self._epoch_start_time
        sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f | %.1f batch/s" % (
            self._current_iter * 100 / self.num_train_batches,
            costs[0],
            float(len(self._train_costs) * self.sync_freq) / elapsed))
        sys.stdout.flush()
    elif 'get_num_batches_done' in req:
        self.num_train_batches = req['get_num_batches_done']
    elif 'get_easgd_alpha' in req:
        response = self._easgd_alpha
    elif 'sync_hyperparams' in req:
        response = {"sync_hyperparams": self.feed_hyperparams()}
    elif 'init_schedule' in req:
        # Schedule keys a worker may set -> server attribute they update.
        schedule_attrs = {
            'learning_rate': '_lr',
            'start_halving_at': 'epoch_start_halving',
            'halving_freq': '_halving_freq',
            'end_at': 'end_at',
            'sync_freq': 'sync_freq',
            'valid_freq': '_valid_freq',
        }
        with self._lock:
            sys.stdout.write("\r")
            sys.stdout.flush()
            self.log("worker {} connected".format(worker_id))
            # The schedule is only applied while self.epoch == 0
            # (presumably before training has started).
            if self.epoch == 0:
                schedule_params = req['init_schedule']
                sch_str = " ".join("{}={}".format(a, b) for (a, b) in schedule_params.items())
                self.log("initialize the schedule with {}".format(sch_str))
                for key, val in schedule_params.items():
                    # Falsy values leave the current setting untouched;
                    # unknown keys are ignored.
                    attr = schedule_attrs.get(key)
                    if val and attr is not None:
                        setattr(self, attr, val)
    elif 'set_names' in req:
        self._training_names = req['training_names']
        self._evaluation_names = req['evaluation_names']

    return response
||
263 | |||
278 |