| Metric | Value |
| --- | --- |
| Conditions | 25 |
| Total Lines | 92 |
| Duplicated Lines | 0 |
| Duplication Ratio | 0 % |
Small methods make your code easier to understand, especially when combined with a good name. Conversely, when a method is small, finding a good name for it is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a good sign that you should extract the commented part into a new method, using the comment as a starting point for the new method's name.
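As a minimal sketch of this, based on the reward bookkeeping visible in the listing below (the name `compute_reward` and the example values are ours, not the project's):

```python
# Example values standing in for the loop variables in the listing below.
last_decision, target_decision, max_position_value = 1, 1, 0.35

def compute_reward(last_decision, target_decision, max_position_value):
    """The old inline comment ("compute reward, zeroing it when the
    position value is too large") becomes the method name and docstring."""
    if max_position_value > 0.8:
        return 0
    return 0.005 if last_decision == target_decision else 0

reward = compute_reward(last_decision, target_decision, max_position_value)  # 0.005
```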
Commonly applied refactorings include:

- Extract Method: pull a cohesive piece of the body into its own, well-named method.
- If many parameters/temporary variables are present: Replace Method with Method Object, which turns the method into a class whose fields replace the temporaries (see the sketch below).
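A minimal sketch of the latter, assuming we wrap `train_func` in a hypothetical method object (the class and method names here are ours, not the project's):

```python
class TrainingRun(object):
    """Method object for train_func: each local variable becomes a field,
    so the loop body can be split into small methods that share state
    without long parameter lists."""

    def __init__(self, trainer, train_set):
        self.trainer = trainer
        self.train_set = train_set
        # Former locals of train_func:
        self.cost_sum = 0.0
        self.total = 0

    def run(self):
        for sample in self.train_set:
            self.process_sample(sample)   # the old loop body, step by step
        return self.summary()             # the old trailing return

    def process_sample(self, sample):
        # The gradient/reward bookkeeping of the original loop would go here.
        self.total += 1

    def summary(self):
        if self.total == 0:
            return "COST OVERFLOW"
        return "J: %.2f" % (self.cost_sum / self.total)
```

`train_func` itself would then shrink to a one-liner such as `return TrainingRun(self, train_set).run()`.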
Complex methods like experiments.attention_models.AttentionTrainer.train_func() often do a lot of different things, and so, usually, does the class around them. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
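In the listing below, `batch_cost`, `batch_reward`, `self.batch_wl_grad`, and `self.batch_grad` all share the `batch_` prefix, which hints at such a component. A sketch under that assumption (the class name `BatchAccumulator` is ours):

```python
class BatchAccumulator(object):
    """Hypothetical extract-class target for the batch_* state of
    AttentionTrainer: it owns the per-batch running totals and knows
    how to reset itself between parameter updates."""

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.cost = 0.0
        self.reward = 0.0
        self.count = 0

    def add(self, cost, reward):
        self.cost += cost
        self.reward += reward
        self.count += 1

    def is_full(self):
        return self.count >= self.batch_size

    def average_reward(self):
        return self.reward / self.batch_size

    def reset(self):
        self.cost = 0.0
        self.reward = 0.0
        self.count = 0
```

`train_func` could then ask one object whether the batch is full and reset it in a single call, instead of juggling `counter`, `batch_cost`, and `batch_reward` by hand.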
```python
#!/usr/bin/env python
# ...

def train_func(self, train_set):
    cost_sum = 0.0
    batch_cost = 0.0
    counter = 0
    total = 0
    total_reward = 0
    batch_reward = 0
    total_position_value = 0
    pena_count = 0
    for d in train_set:
        pairs = self.grad_func(*d)
        cost = pairs[0]
        if cost > 10 or np.isnan(cost):
            sys.stdout.write("X")
            sys.stdout.flush()
            continue
        batch_cost += cost

        wl_grad = pairs[1]
        max_position_value = np.max(np.absolute(pairs[2]))
        total_position_value += max_position_value
        last_decision = pairs[3]
        target_decision = d[1][0]
        reward = 0.005 if last_decision == target_decision else 0
        if max_position_value > 0.8:
            reward = 0
        total_reward += reward
        batch_reward += reward
        if self.last_average_reward == 999 and total > 2000:
            self.last_average_reward = total_reward / total
        if not self.disable_reinforce:
            self.batch_wl_grad += wl_grad * - (reward - self.last_average_reward)
        if not self.disable_backprop:
            for grad_cache, grad in zip(self.batch_grad, pairs[4:]):
                grad_cache += grad
        counter += 1
        total += 1
        if counter >= self.batch_size:
            if total == counter: counter -= 1
            self.update_parameters(self.last_average_reward < 999)

            # Clean batch gradients
            if not self.disable_reinforce:
                self.batch_wl_grad *= 0
            if not self.disable_backprop:
                for grad_cache in self.batch_grad:
                    grad_cache *= 0

            if total % 1000 == 0:
                sys.stdout.write(".")
                sys.stdout.flush()

            # Cov
            if not self.disable_reinforce:
                cov_changed = False
                if batch_reward / self.batch_size < 0.001:
                    if not self.large_cov_mode:
                        if pena_count > 20:
                            self.layer.cov.set_value(self.layer.large_cov)
                            print "[LCOV]",
                            cov_changed = True
                        else:
                            pena_count += 1
                    else:
                        pena_count = 0
                else:
                    if self.large_cov_mode:
                        if pena_count > 20:
                            self.layer.cov.set_value(self.layer.small_cov)
                            print "[SCOV]",
                            cov_changed = True
                        else:
                            pena_count += 1
                    else:
                        pena_count = 0
                if cov_changed:
                    self.large_cov_mode = not self.large_cov_mode
                    self.layer.cov_inv_var.set_value(np.array(LA.inv(self.layer.cov.get_value()), dtype=FLOATX))
                    self.layer.cov_det_var.set_value(LA.det(self.layer.cov.get_value()))

            # Clean batch cost
            counter = 0
            cost_sum += batch_cost
            batch_cost = 0.0
            batch_reward = 0
    if total == 0:
        return "COST OVERFLOW"

    sys.stdout.write("\n")
    self.last_average_reward = (total_reward / total)
    self.turn += 1
    return "J: %.2f, Avg R: %.4f, Avg P: %.2f" % ((cost_sum / total), self.last_average_reward, (total_position_value / total))
```
Duplicated code is one of the most pungent code smells. If the same code appears in three or more different places, we strongly encourage you to extract it into a single class or operation.
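The covariance-switching code in train_func() above is a concrete case: the `[LCOV]` and `[SCOV]` branches differ only in which covariance matrix is installed and which label is printed. A sketch of one shared helper (the name `_switch_cov` is ours; `sys`, `np`, `LA`, and `FLOATX` are assumed to be the module's existing imports and constant):

```python
import sys
import numpy as np
import numpy.linalg as LA

FLOATX = "float32"  # stand-in for the module's own FLOATX constant

# Sketch of a method on AttentionTrainer:
def _switch_cov(self, new_cov, label):
    """Shared body of the two near-identical branches: install the new
    covariance matrix, flip the mode flag, and refresh the cached
    inverse and determinant."""
    self.layer.cov.set_value(new_cov)
    sys.stdout.write(label + " ")
    self.large_cov_mode = not self.large_cov_mode
    self.layer.cov_inv_var.set_value(
        np.array(LA.inv(self.layer.cov.get_value()), dtype=FLOATX))
    self.layer.cov_det_var.set_value(LA.det(self.layer.cov.get_value()))
```

Each branch would then reduce to a single call, e.g. `self._switch_cov(self.layer.large_cov, "[LCOV]")`, with only the `pena_count` bookkeeping left at the call site.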
You can also find more detailed suggestions in the “Code” section of your repository.