Conditions | 14
Total Lines | 84
Duplicated Lines | 39
Duplication Ratio | 46.43 %
Changes | 2
Bugs | 0
Features | 0
Small methods make your code easier to understand, especially when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign that the commented part should be extracted into a new method, with the comment as a starting point for the new method's name.
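A minimal sketch of that idea (the class and names here are hypothetical, not from the code under review):

```python
# Before: a comment labels a block inside a longer method.
class Invoice:
    def total(self):
        amount = sum(item.price for item in self.items)
        # apply seasonal discount
        if self.season == 'winter':
            amount *= 0.9
        return amount


# After: the comment becomes the method name.
class Invoice:
    def total(self):
        return self._apply_seasonal_discount(
            sum(item.price for item in self.items))

    def _apply_seasonal_discount(self, amount):
        return amount * 0.9 if self.season == 'winter' else amount
```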
Commonly applied refactorings include:

- Extract Method
- If many parameters or temporary variables are present: Replace Method with Method Object (sketched below)
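A minimal sketch of Replace Method with Method Object, with hypothetical names: the method's temporary variables become fields of a dedicated class, so helper methods can share them without long parameter lists.

```python
# Hypothetical sketch: the temporaries of a long price() method become
# fields of a method object, making further extraction easy.
class PriceCalculator:
    def __init__(self, order):
        self.order = order
        self.base_price = 0.0
        self.discount = 0.0

    def compute(self):
        self.base_price = sum(item.price for item in self.order.items)
        self.discount = 0.1 * self.base_price if self.base_price > 100 else 0.0
        return self.base_price - self.discount


# The original method shrinks to a single delegation:
class Order:
    def price(self):
        return PriceCalculator(self).compute()
```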
Complex classes like LSTM often do a lot of different things, and a long method such as LSTM.fit() is a typical symptom. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
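In fit() below, for instance, summaries_dir, train_writer, and test_writer always travel together. A sketch of pulling them into their own class (the class name and method names are hypothetical):

```python
import tensorflow as tf


# Hypothetical sketch: the summary-writing state used by fit() below
# forms a cohesive component that could live in its own class.
class SummaryLogger:
    def __init__(self, summaries_dir):
        self.train_writer = tf.summary.FileWriter(summaries_dir + '/train')
        self.test_writer = tf.summary.FileWriter(summaries_dir + '/test')

    def log_train(self, summary, step):
        self.train_writer.add_summary(summary, step)

    def log_test(self, summary, step):
        self.test_writer.add_summary(summary, step)
```

The fit() method flagged by the report follows.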
```python
import os

# Imports required by this excerpt; in the full file they presumably sit in
# the elided lines (2-62), together with the project-local helpers
# BatchSequenceInjector, ConstIterations, MonitorBased, and logger.
import numpy as np
import tensorflow as tf


def fit(self, x, y, batch_size=100, iter_num=100, summaries_dir=None, summary_interval=10,
        test_x=None, test_y=None, session=None, criterion='const_iteration'):
    """Fit the model to the dataset.

    Args:
        x (:obj:`numpy.ndarray`): Input features of shape (num_samples, num_features).
        y (:obj:`numpy.ndarray`): Corresponding labels of shape (num_samples) for binary
            classification, or (num_samples, num_classes) for multi-class classification.
        batch_size (:obj:`int`): Batch size used in gradient descent.
        iter_num (:obj:`int`): Number of training iterations for the constant-iteration
            criterion, or step depth for the monitor-based stopping criterion.
        summaries_dir (:obj:`str`): Path of the directory to store summaries and saved values.
        summary_interval (:obj:`int`): The step interval to export variable summaries.
        test_x (:obj:`numpy.ndarray`): Test feature array used for monitoring training progress.
        test_y (:obj:`numpy.ndarray`): Test label array used for monitoring training progress.
        session (:obj:`tensorflow.Session`): Session to run training functions.
        criterion (:obj:`str`): Stopping criterion, either 'const_iteration' or 'monitor_based'.
    """
    if session is None:
        if self.sess is None:
            session = tf.Session()
            self.sess = session
        else:
            session = self.sess
    if summaries_dir is not None:
        train_writer = tf.summary.FileWriter(summaries_dir + '/train')
        test_writer = tf.summary.FileWriter(summaries_dir + '/test')
    session.run(tf.global_variables_initializer())
    # Set up the batch injector
    injector = BatchSequenceInjector(data_x=x, data_y=y, batch_size=batch_size,
                                     seq_len=self.num_steps)
    # Build the stopping criterion
    if criterion == 'const_iteration':
        _criterion = ConstIterations(num_iters=iter_num)
    elif criterion == 'monitor_based':
        # Hold out the last fifth of the data as a validation set.
        # (Note: this branch assumes summaries_dir is set, since the best
        # checkpoint is saved there.)
        num_samples = x.shape[0]
        valid_set_len = int(1 / 5 * num_samples)
        valid_x = x[num_samples - valid_set_len:num_samples, :]
        valid_y = y[num_samples - valid_set_len:num_samples, :]
        x = x[0:num_samples - valid_set_len, :]
        y = y[0:num_samples - valid_set_len, :]
        _criterion = MonitorBased(n_steps=iter_num,
                                  monitor_fn=self.predict_accuracy,
                                  monitor_fn_args=(valid_x, valid_y[self.num_steps:, :]),
                                  save_fn=tf.train.Saver().save,
                                  save_fn_args=(session, summaries_dir + '/best.ckpt'))
    else:
        logger.error('Wrong criterion %s specified.' % criterion)
        return
    # Train/test sequences for brief reporting of accuracy and loss
    train_seq_x, train_seq_y = BatchSequenceInjector.to_sequence(
        self.num_steps, x, y, start=0, end=2000
    )
    if (test_x is not None) and (test_y is not None):
        test_seq_x, test_seq_y = BatchSequenceInjector.to_sequence(
            self.num_steps, test_x, test_y, start=0, end=2000
        )
    # Iteration starts
    i = 0
    while _criterion.continue_learning():
        batch_x, batch_y = injector.next_batch()
        if summaries_dir is not None and (i % summary_interval == 0):
            summary, loss, accuracy = session.run(
                [self.merged, self.loss, self.accuracy],
                feed_dict={self.x: train_seq_x, self.y_: train_seq_y,
                           self.init_state: np.zeros((train_seq_x.shape[0], 2 * self.num_units))}
            )
            train_writer.add_summary(summary, i)
            logger.info('Step %d, train_set accuracy %g, loss %g' % (i, accuracy, loss))
            if (test_x is not None) and (test_y is not None):
                merged, accuracy = session.run(
                    [self.merged, self.accuracy],
                    feed_dict={self.x: test_seq_x, self.y_: test_seq_y,
                               self.init_state: np.zeros((test_seq_x.shape[0], 2 * self.num_units))})
                test_writer.add_summary(merged, i)
                logger.info('test_set accuracy %g' % accuracy)
        loss, accuracy, _ = session.run(
            [self.loss, self.accuracy, self.fit_step],
            feed_dict={self.x: batch_x, self.y_: batch_y,
                       self.init_state: np.zeros((batch_x.shape[0], 2 * self.num_units))})
        i += 1
    # Training finished: for monitor_based, restore the best checkpoint
    if criterion == 'monitor_based':
        tf.train.Saver().restore(session, os.path.join(summaries_dir, 'best.ckpt'))
    logger.debug('Total Epoch: %d, current batch %d', injector.num_epochs, injector.cur_batch)
```
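To connect the listing back to the advice above, here is one way the commented blocks in fit() could be extracted. The helper names and exact signatures are hypothetical, not part of the original code:

```python
# Hypothetical sketch of Extract Method applied to fit(); names and
# signatures are illustrative, not part of the original code.
def _resolve_session(self, session):
    """Reuse the provided or cached session, or create and cache a new one."""
    if session is not None:
        return session
    if self.sess is None:
        self.sess = tf.Session()
    return self.sess

def _build_stop_criterion(self, criterion, iter_num, session, summaries_dir, x, y):
    """Return (stop_criterion, train_x, train_y), splitting off a
    validation set when the monitor-based criterion is selected."""
    if criterion == 'const_iteration':
        return ConstIterations(num_iters=iter_num), x, y
    # monitor_based: hold out the last fifth of the data for validation
    split = x.shape[0] - int(x.shape[0] / 5)
    stop = MonitorBased(n_steps=iter_num,
                        monitor_fn=self.predict_accuracy,
                        monitor_fn_args=(x[split:, :], y[split:, :][self.num_steps:, :]),
                        save_fn=tf.train.Saver().save,
                        save_fn_args=(session, summaries_dir + '/best.ckpt'))
    return stop, x[:split, :], y[:split, :]
```

With those two helpers in place, the body of fit() reduces to the summary setup and the batch loop, which both shortens the method and makes the duplicated spans flagged by the report easier to factor out.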