| Conditions | 28 |
| Total Lines | 107 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 1 | ||
| Bugs | 0 | Features | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include: Extract Method. If many parameters or temporary variables are present: Replace Method with Method Object.
Complex methods like ScheduledTrainingServer.handle_control() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to find such a component is to look for statements or local variables that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | #!/usr/bin/env python |
||
def handle_control(self, req, worker_id):
    """Handle a control request received from a worker.

    Args:
        req: either the string 'next' (the worker asks for work) or a dict
            whose keys identify a completed action ('eval_done',
            'valid_done', 'train_done', 'get_num_batches_done',
            'get_easgd_alpha', 'sync_hyperparams', 'set_names').
        worker_id: identifier of the requesting worker.

    Returns:
        string or dict: the response for the worker:

        'stop' - the worker should quit
        'wait' - wait for 1 second
        'eval' - evaluate on valid and test set to start a new epoch
        'sync_hyperparams' - set learning rate
        'valid' - evaluate on valid and test set, then save the params
        'train' - train next batches
    """
    if self.start_time is None:
        self.start_time = time.time()
    response = ""

    def _report_valid(label, valid_costs):
        # Shared by the 'eval_done' and 'valid_done' branches (previously
        # duplicated): update the best validation cost seen so far and log
        # the result; "*" marks a new best.
        valid_J = valid_costs[0][1]
        if valid_J < self._best_valid_cost:
            self._best_valid_cost = valid_J
            star_str = "*"
        else:
            star_str = ""
        logging.info("valid ({}) {} {}".format(
            label,
            self.get_monitor_string(valid_costs),
            star_str))

    if req == 'next':
        if self.num_train_batches == 0:
            # Batch count still unknown: ask the worker to report it first.
            response = "get_num_batches"
        elif self._done:
            response = "stop"
            self.worker_is_done(worker_id)
        elif self._evaluating:
            response = 'wait'
        elif not self.batch_pool:
            # End of one iteration: log mean training costs, then switch
            # into evaluation mode.
            if self._train_costs:
                with self._lock:
                    sys.stdout.write("\r")
                    sys.stdout.flush()
                    mean_costs = [np.mean([c[i] for c in self._train_costs])
                                  for i in range(len(self._training_names))]
                    logging.info("train (epoch={:2d}) {}".format(
                        self.epoch,
                        self.get_monitor_string(zip(self._training_names, mean_costs)))
                    )
            response = {'eval': None, 'best_valid_cost': self._best_valid_cost}
            self._evaluating = True
        else:
            # Continue training.
            if worker_id not in self.prepared_worker_pool:
                # First contact with this worker: sync hyper-parameters.
                response = {"sync_hyperparams": self.feed_hyperparams()}
                self.prepared_worker_pool.add(worker_id)
            elif self._iters_from_last_valid >= self._valid_freq:
                response = {'valid': None, 'best_valid_cost': self._best_valid_cost}
                self._iters_from_last_valid = 0
            else:
                response = {"train": self.feed_batches()}
    elif 'eval_done' in req:
        with self._lock:
            self._evaluating = False
            sys.stdout.write("\r")
            sys.stdout.flush()
            if 'test_costs' in req:
                logging.info("test (epoch={:2d}) {}".format(
                    self.epoch,
                    self.get_monitor_string(req['test_costs']))
                )
            if 'valid_costs' in req:
                # Same output as "valid (epoch={:2d}) ..." in the original.
                _report_valid("epoch={:2d}".format(self.epoch), req['valid_costs'])
            continue_training = self.prepare_epoch()
            if not continue_training:
                self._done = True
                logging.info("training time {:.4f}s".format(time.time() - self.start_time))
                response = "stop"
    elif 'valid_done' in req:
        # Dry-run validation: no epoch transition, just report the cost.
        with self._lock:
            sys.stdout.write("\r")
            sys.stdout.flush()
            if 'valid_costs' in req:
                # Same output as "valid ( dryrun ) ..." in the original.
                _report_valid(" dryrun ", req['valid_costs'])
    elif 'train_done' in req:
        costs = req['costs']
        self._train_costs.append(costs)
        # Progress line: percentage of the epoch done and the first cost.
        sys.stdout.write("\x1b[2K\r> %d%% | J=%.2f" % (
            self._current_iter * 100 / self.num_train_batches,
            costs[0]))
        sys.stdout.flush()
    elif 'get_num_batches_done' in req:
        self.num_train_batches = req['get_num_batches_done']
    elif 'get_easgd_alpha' in req:
        response = self._easgd_alpha
    elif 'sync_hyperparams' in req:
        response = {"sync_hyperparams": self.feed_hyperparams()}
    elif 'set_names' in req:
        self._training_names = req['training_names']
        self._evaluation_names = req['evaluation_names']

    return response
||
| 213 | |||
| 233 |