| Metric      | Value |
| ----------- | ----- |
| Conditions  | 12    |
| Total Lines | 72    |
| Code Lines  | 58    |
| Lines       | 0     |
| Ratio       | 0 %   |
| Changes     | 0     |
Small methods make your code easier to understand, especially when combined with a good name. And if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for naming it.
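To make this concrete, here is a minimal sketch of that move; all names here are made up for illustration, not taken from the project:

```python
# Before: a comment explains what the block does.
def report_before(step, losses):
    # compute the average loss over the reporting window
    total = 0.0
    for loss in losses:
        total += loss
    average = total / len(losses)
    print("[{}] loss={:.2f}".format(step, average))


# After: the comment has become the name of a small helper method.
def average_loss(losses):
    return sum(losses) / len(losses)


def report_after(step, losses):
    print("[{}] loss={:.2f}".format(step, average_loss(losses)))
```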
Commonly applied refactorings include:

- Extract Method

If many parameters/temporary variables are present:

- Replace Temp with Query
- Introduce Parameter Object
- Preserve Whole Object
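For the parameter-heavy case, here is a minimal sketch of Introduce Parameter Object; the names are hypothetical and chosen only to echo the training script below:

```python
from dataclasses import dataclass


# Values that always travel together are bundled into one object.
@dataclass
class ReportSettings:
    frequency: int
    log_dir: str
    flush_secs: int = 20


# Before: def report(step, loss, frequency, log_dir, flush_secs): ...
# After: one cohesive parameter replaces three loose ones.
def report(step, loss, settings: ReportSettings):
    if step % settings.frequency == 0:
        print("[{}] loss={:.2f} (logs in {})".format(
            step, loss, settings.log_dir))


report(100, 0.42, ReportSettings(frequency=100, log_dir="logs"))
```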
Complex functions like e2edutch.train.main() often do a lot of different things. To break such a function down, we need to identify cohesive components within it. A common approach is to look for variables and statements that share the same prefixes or suffixes, or that always change together.
Once you have determined which parts belong together, you can extract them: Extract Method for a block of statements, or Extract Class when the grouped fields and methods make sense as an object of their own. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often quicker to apply.
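As a minimal sketch of Extract Class (the class and field names are hypothetical, not from the project): fields sharing the `eval_` prefix change together, so they move into their own class.

```python
# Before: the eval_* fields form a hidden component inside Trainer.
class TrainerBefore:
    def __init__(self):
        self.eval_frequency = 5000
        self.eval_path = "dev.jsonl"
        self.eval_max_f1 = 0.0


# After: Extract Class gives the component a home of its own.
class EvalState:
    def __init__(self, frequency, path):
        self.frequency = frequency
        self.path = path
        self.max_f1 = 0.0


class TrainerAfter:
    def __init__(self):
        self.eval = EvalState(frequency=5000, path="dev.jsonl")
```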
The reported function, as it appears in the source (the report elides lines 2-39):

```python
#!/usr/bin/env python

# ... lines 2-39 (imports and the get_parser() definition) omitted ...


def main(args=None):
    args = get_parser().parse_args()
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    config = util.initialize_from_env(args.config,
                                      args.cfg_file,
                                      args.model_cfg_file)
    # Overwrite train and eval file if specified
    if args.train is not None:
        config['train_path'] = args.train
    if args.eval is not None:
        config['eval_path'] = args.eval
    if args.eval_conll is not None:
        config['conll_eval_path'] = args.eval_conll

    report_frequency = config["report_frequency"]
    eval_frequency = config["eval_frequency"]

    model = cm.CorefModel(config)
    saver = tf.train.Saver()

    log_dir = os.path.join(config['log_root'], config['log_dir'])
    writer = tf.summary.FileWriter(log_dir, flush_secs=20)

    max_f1 = 0

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        model.start_enqueue_thread(session)
        accumulated_loss = 0.0

        ckpt = tf.train.get_checkpoint_state(log_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restoring from: {}".format(ckpt.model_checkpoint_path))
            saver.restore(session, ckpt.model_checkpoint_path)

        initial_time = time.time()
        while True:
            tf_loss, tf_global_step, _ = session.run(
                [model.loss, model.global_step, model.train_op])
            accumulated_loss += tf_loss

            if tf_global_step % report_frequency == 0:
                total_time = time.time() - initial_time
                steps_per_second = tf_global_step / total_time

                average_loss = accumulated_loss / report_frequency
                print("[{}] loss={:.2f}, steps/s={:.2f}"
                      .format(tf_global_step,
                              average_loss,
                              steps_per_second))
                writer.add_summary(util.make_summary(
                    {"loss": average_loss}), tf_global_step)
                accumulated_loss = 0.0

            if tf_global_step % eval_frequency == 0:
                saver.save(session, os.path.join(log_dir, "model"),
                           global_step=tf_global_step)
                eval_summary, eval_f1 = model.evaluate(session)

                if eval_f1 > max_f1:
                    max_f1 = eval_f1
                    util.copy_checkpoint(os.path.join(
                        log_dir, "model-{}".format(tf_global_step)),
                        os.path.join(log_dir, "model.max.ckpt"))

                writer.add_summary(eval_summary, tf_global_step)
                writer.add_summary(util.make_summary(
                    {"max_eval_f1": max_f1}), tf_global_step)

                print("[{}] eval_f1={:.2f}, max_f1={:.2f}".format(
                    tf_global_step, eval_f1, max_f1))
```
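As one illustration of the Extract Method advice above, main() could be split along its existing comment and blank-line boundaries. This is only a sketch, assuming the same util, cm, tf, logging, and get_parser names from the listing; the helper names load_config, maybe_restore, and train_loop are hypothetical, and the sketch also passes the args parameter through to parse_args so that it is actually honored:

```python
def load_config(args):
    """Build the run configuration, honoring command-line overrides."""
    config = util.initialize_from_env(args.config,
                                      args.cfg_file,
                                      args.model_cfg_file)
    # Overwrite train and eval file if specified
    if args.train is not None:
        config['train_path'] = args.train
    if args.eval is not None:
        config['eval_path'] = args.eval
    if args.eval_conll is not None:
        config['conll_eval_path'] = args.eval_conll
    return config


def maybe_restore(session, saver, log_dir):
    """Restore the latest checkpoint from log_dir, if one exists."""
    ckpt = tf.train.get_checkpoint_state(log_dir)
    if ckpt and ckpt.model_checkpoint_path:
        print("Restoring from: {}".format(ckpt.model_checkpoint_path))
        saver.restore(session, ckpt.model_checkpoint_path)


def main(args=None):
    args = get_parser().parse_args(args)
    if args.verbose:
        logging.basicConfig(level=logging.DEBUG)
    config = load_config(args)
    model = cm.CorefModel(config)
    saver = tf.train.Saver()
    log_dir = os.path.join(config['log_root'], config['log_dir'])
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        model.start_enqueue_thread(session)
        maybe_restore(session, saver, log_dir)
        # The report/eval loop would be extracted the same way.
        train_loop(session, model, saver, config, log_dir)
```

Each helper now fits on one screen, does one thing, and is named after the comment or blank-line section it replaced, which is exactly the signal described above.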