| Metric | Value |
|---|---|
| Conditions | 25 |
| Total Lines | 92 |
| Lines | 92 |
| Ratio | 100% |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method. The comment then makes a good starting point for naming the new method.
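Applied to the `train_func()` listing below, the block headed `# Clean batch gradients` could become a method named after its comment. This is a minimal sketch, assuming the gradient accumulators are NumPy arrays; `TrainerSketch` and `clear_batch_gradients()` are hypothetical names, and only the method body comes from the listing.

```python
import numpy as np


class TrainerSketch:
    """Hypothetical stand-in for AttentionTrainer with only the batch buffers."""

    def __init__(self, disable_reinforce=False, disable_backprop=False):
        self.disable_reinforce = disable_reinforce
        self.disable_backprop = disable_backprop
        # Assumption: the gradient accumulators are NumPy arrays.
        self.batch_wl_grad = np.ones(3)
        self.batch_grad = [np.ones((2, 2)), np.ones(4)]

    def clear_batch_gradients(self):
        """Body of the '# Clean batch gradients' block, named after its comment."""
        if not self.disable_reinforce:
            self.batch_wl_grad *= 0
        if not self.disable_backprop:
            for grad_cache in self.batch_grad:
                # In-place multiply: zeroes the array itself, not just the loop variable.
                grad_cache *= 0


trainer = TrainerSketch()
trainer.clear_batch_gradients()  # stands in for the inline block in train_func()
```

Inside `train_func()` the five inline lines would shrink to a single `self.clear_batch_gradients()` call. The in-place `*= 0` relies on NumPy arrays mutating themselves; on a plain Python list the same statement would empty the list instead.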
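Commonly applied refactorings include:

- Extract Method

If many parameters/temporary variables are present:

- Replace Temp with Query
- Introduce Parameter Object
- Preserve Whole Object
- Replace Method with Method Object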
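In `train_func()` the temporaries are the running totals initialised at the top of the method (`cost_sum`, `batch_cost`, `total`, `total_reward`, `batch_reward`, `total_position_value`, alongside `counter` and `pena_count`). A first step toward Replace Method with Method Object is to move the totals into one small object with intention-revealing methods. This is a sketch: `EpochTally` and its method names are hypothetical, and only the arithmetic and the summary format string are taken from the listing.

```python
from dataclasses import dataclass


@dataclass
class EpochTally:
    """Hypothetical holder for the running totals train_func() keeps as locals."""

    cost_sum: float = 0.0
    batch_cost: float = 0.0
    total: int = 0
    total_reward: float = 0.0
    batch_reward: float = 0.0
    total_position_value: float = 0.0

    def add_sample(self, cost, reward, max_position_value):
        """Record one accepted (non-overflowing) training sample."""
        self.batch_cost += cost
        self.batch_reward += reward
        self.total_reward += reward
        self.total_position_value += max_position_value
        self.total += 1

    def end_batch(self):
        """Fold the batch cost into the epoch sum and reset per-batch values."""
        self.cost_sum += self.batch_cost
        self.batch_cost = 0.0
        self.batch_reward = 0.0

    def average_reward(self):
        return self.total_reward / self.total

    def summary(self):
        """Same strings that train_func() returns."""
        if self.total == 0:
            return "COST OVERFLOW"
        return "J: %.2f, Avg R: %.4f, Avg P: %.2f" % (
            self.cost_sum / self.total,
            self.average_reward(),
            self.total_position_value / self.total,
        )


tally = EpochTally()
tally.add_sample(cost=2.0, reward=0.005, max_position_value=0.5)
tally.add_sample(cost=4.0, reward=0.0, max_position_value=0.9)
tally.end_batch()
print(tally.summary())  # J: 3.00, Avg R: 0.0025, Avg P: 0.70
```

As in the listing, `cost_sum` only grows at the end of a full batch, so the cost of a trailing partial batch is left out of `J` unless `end_batch()` is also called after the loop.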
Complex classes like AttentionTrainer (the class that owns train_func()) often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach is to look for fields/methods that share the same prefixes or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
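In AttentionTrainer, the state read by the `# Cov` block of `train_func()` is such a component: `large_cov_mode`, `layer.cov`, `layer.large_cov`, `layer.small_cov`, `layer.cov_inv_var` and `layer.cov_det_var` share the `cov` stem, and the local `pena_count` only exists to drive them. Below is a sketch of Extract Class for the switching decision, assuming the layer updates (`set_value`, inverse, determinant) stay in the trainer; `CovarianceSwitch`, `update()`, `start_epoch()` and the constant names are hypothetical, while the thresholds 20 and 0.001 are copied from the listing.

```python
class CovarianceSwitch:
    """Hypothetical class extracted from the '# Cov' block of train_func()."""

    PATIENCE = 20       # the listing switches once pena_count > 20
    LOW_REWARD = 0.001  # threshold on the mean reward per batch

    def __init__(self, large_cov_mode=False):
        self.large_cov_mode = large_cov_mode
        self.pena_count = 0

    def start_epoch(self):
        # train_func() re-initialises pena_count to 0 on every call.
        self.pena_count = 0

    def update(self, batch_reward, batch_size):
        """Return True when the covariance mode flips after this batch."""
        want_large = batch_reward / batch_size < self.LOW_REWARD
        if want_large == self.large_cov_mode:
            # Already in the preferred mode: reset the penalty counter.
            self.pena_count = 0
            return False
        if self.pena_count > self.PATIENCE:
            # As in the listing, the counter is not reset when the mode flips.
            self.large_cov_mode = not self.large_cov_mode
            return True
        self.pena_count += 1
        return False


switch = CovarianceSwitch()
for batch in range(25):
    if switch.update(batch_reward=0.0, batch_size=32):
        print("switched to large covariance after batch", batch)
# prints: switched to large covariance after batch 21
```

The four nested branches of the original collapse into one comparison between the preferred and the current mode. Because the counter survives a switch, a high-reward batch right after switching to the large covariance flips the mode straight back, which is worth confirming as intended before the refactoring is applied.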
```python
#!/usr/bin/env python
# ... (imports and the beginning of the AttentionTrainer class omitted)

def train_func(self, train_set):
    cost_sum = 0.0
    batch_cost = 0.0
    counter = 0
    total = 0
    total_reward = 0
    batch_reward = 0
    total_position_value = 0
    pena_count = 0
    for d in train_set:
        pairs = self.grad_func(*d)
        cost = pairs[0]
        # Skip samples whose cost is too large or NaN
        if cost > 10 or np.isnan(cost):
            sys.stdout.write("X")
            sys.stdout.flush()
            continue
        batch_cost += cost

        wl_grad = pairs[1]
        max_position_value = np.max(np.absolute(pairs[2]))
        total_position_value += max_position_value
        last_decision = pairs[3]
        target_decision = d[1][0]
        # Reward a correct decision, unless the largest position value exceeds 0.8
        reward = 0.005 if last_decision == target_decision else 0
        if max_position_value > 0.8:
            reward = 0
        total_reward += reward
        batch_reward += reward
        # 999 is a sentinel for "no reward baseline yet"
        if self.last_average_reward == 999 and total > 2000:
            self.last_average_reward = total_reward / total
        # Accumulate gradients for the current batch
        if not self.disable_reinforce:
            self.batch_wl_grad += wl_grad * - (reward - self.last_average_reward)
        if not self.disable_backprop:
            for grad_cache, grad in zip(self.batch_grad, pairs[4:]):
                grad_cache += grad
        counter += 1
        total += 1
        # End of a batch: update parameters, then reset per-batch state
        if counter >= self.batch_size:
            if total == counter: counter -= 1
            self.update_parameters(self.last_average_reward < 999)

            # Clean batch gradients
            if not self.disable_reinforce:
                self.batch_wl_grad *= 0
            if not self.disable_backprop:
                for grad_cache in self.batch_grad:
                    grad_cache *= 0

            if total % 1000 == 0:
                sys.stdout.write(".")
                sys.stdout.flush()

            # Cov
            if not self.disable_reinforce:
                cov_changed = False
                if batch_reward / self.batch_size < 0.001:
                    if not self.large_cov_mode:
                        if pena_count > 20:
                            self.layer.cov.set_value(self.layer.large_cov)
                            print "[LCOV]",
                            cov_changed = True
                        else:
                            pena_count += 1
                    else:
                        pena_count = 0
                else:
                    if self.large_cov_mode:
                        if pena_count > 20:
                            self.layer.cov.set_value(self.layer.small_cov)
                            print "[SCOV]",
                            cov_changed = True
                        else:
                            pena_count += 1
                    else:
                        pena_count = 0
                if cov_changed:
                    self.large_cov_mode = not self.large_cov_mode
                    self.layer.cov_inv_var.set_value(np.array(LA.inv(self.layer.cov.get_value()), dtype=FLOATX))
                    self.layer.cov_det_var.set_value(LA.det(self.layer.cov.get_value()))

            # Clean batch cost
            counter = 0
            cost_sum += batch_cost
            batch_cost = 0.0
            batch_reward = 0
    if total == 0:
        return "COST OVERFLOW"

    sys.stdout.write("\n")
    self.last_average_reward = (total_reward / total)
    self.turn += 1
    return "J: %.2f, Avg R: %.4f, Avg P: %.2f" % ((cost_sum / total), self.last_average_reward, (total_position_value / total))
```
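The reward rule in the listing (0.005 for a correct decision, zeroed when the largest absolute position value exceeds 0.8) and the REINFORCE scale applied to `wl_grad` do not depend on any trainer state, so they are natural Extract Method targets that can be tested on their own. A sketch; `compute_reward`, `reinforce_coefficient` and their keyword parameters are hypothetical names, with the defaults copied from the listing.

```python
def compute_reward(last_decision, target_decision, max_position_value,
                   hit_reward=0.005, position_limit=0.8):
    """Reward rule from train_func(): a fixed bonus for a correct decision,
    cancelled when the largest absolute position value exceeds the limit."""
    if max_position_value > position_limit:
        return 0.0
    return hit_reward if last_decision == target_decision else 0.0


def reinforce_coefficient(reward, baseline):
    """Factor train_func() multiplies into wl_grad: -(reward - baseline)."""
    return -(reward - baseline)


r = compute_reward(last_decision=2, target_decision=2, max_position_value=0.3)
print(r)  # 0.005
```

With these helpers the loop body would read `reward = compute_reward(...)` and `self.batch_wl_grad += wl_grad * reinforce_coefficient(reward, self.last_average_reward)`, and the magic numbers gain names.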