1
|
|
|
from __future__ import print_function |
2
|
|
|
|
3
|
|
|
import datetime |
4
|
|
|
import logging |
5
|
|
|
from abc import ABCMeta, abstractmethod |
6
|
|
|
|
7
|
|
|
import progressbar |
8
|
|
|
from six import add_metaclass |
9
|
|
|
from toolz import first |
10
|
|
|
|
11
|
|
|
logger = logging.getLogger(__name__) |
12
|
|
|
|
13
|
|
|
|
14
|
|
|
def callback(func):
    """Mark ``func`` as a training callback and return it unchanged.

    The marker is the ``_is_callback`` attribute, set to ``True`` on the
    function object itself; no wrapper is created.

    """
    setattr(func, '_is_callback', True)
    return func
17
|
|
|
|
18
|
|
|
|
19
|
|
|
class TrainingExtension(object):
    """Base class shared by all training extensions.

    An extension bundles a set of callbacks around a common context.
    The callbacks are invoked at particular stages of the training
    procedure and usually contribute some extra functionality to it,
    such as running validation on auxiliary datasets or early stopping.

    Parameters
    ----------
    name : str, optional
        The name of the extension. Names make it possible to tell apart
        several extensions of the same type that belong to the same main
        loop. When omitted (or empty), the name of the class is used.

    Attributes
    ----------
    main_loop : :class:`.MainLoop`
        The main loop to which the extension belongs.
    name : str
        The name of the extension.

    """
    def __init__(self, name=None):
        # Any falsy name (``None``, ``""``) falls back to the class name.
        self.name = name or self.__class__.__name__

    @property
    def main_loop(self):
        """The main loop this extension has been attached to.

        Raises
        ------
        ValueError
            If no main loop has been assigned yet.

        """
        if hasattr(self, '_main_loop'):
            return self._main_loop
        raise ValueError("main loop must be assigned to extension first")

    @main_loop.setter
    def main_loop(self, value):
        self._main_loop = value

    def dispatch(self, callback_name, *args):
        """Invoke the callback called `callback_name`.

        This indirection exists so that subclasses of
        :class:`TrainingExtension` can intercept callback invocations,
        e.g. to suppress them unless some condition holds. The default
        implementation looks the callback up by name and calls it.

        Parameters
        ----------
        callback_name : str
            The name of the callback method to run.
        \\*args
            Positional arguments forwarded to the callback.

        """
        handler = getattr(self, str(callback_name))
        handler(*args)

    @callback
    def on_resumption(self):
        """The callback invoked after training is resumed."""
        pass

    @callback
    def on_error(self, exception):
        """The callback invoked when an error occurs.

        Parameters
        ----------
        exception : object
            Exception occurred during the main loop run.

        """
        pass

    @callback
    def before_training(self):
        """The callback invoked before training is started."""
        pass

    @callback
    def before_epoch(self):
        """The callback invoked before starting an epoch."""
        pass

    @callback
    def before_batch(self, batch):
        """The callback invoked before a batch is processed.

        Parameters
        ----------
        batch : object
            The data batch to be processed.

        """
        pass

    @callback
    def after_batch(self, batch):
        """The callback invoked after a batch is processed.

        Parameters
        ----------
        batch : object
            The data batch just processed.

        """
        pass

    @callback
    def after_epoch(self):
        """The callback invoked after an epoch is finished."""
        pass

    @callback
    def after_training(self):
        """The callback invoked after training is finished."""
        pass

    @callback
    def on_interrupt(self):
        """The callback invoked when training is interrupted."""
        pass
134
|
|
|
|
135
|
|
|
|
136
|
|
|
class CallbackName(str):
    """A name of a TrainingExtension callback.

    Instances behave like ordinary strings, except that equality
    comparison validates the other operand: it must be the name of one
    of the methods of :class:`TrainingExtension` marked as a callback
    (i.e. carrying a truthy ``_is_callback`` attribute).

    Raises
    ------
    :class:`TypeError` on comparison with a string which is not a name of
    TrainingExtension callback.

    """
    # Overriding ``__eq__`` makes Python 3 implicitly set ``__hash__`` to
    # ``None``, which would make callback names unusable in sets and as
    # dictionary keys. Restore the string hash so that a CallbackName
    # hashes exactly like the plain string it compares equal to.
    __hash__ = str.__hash__

    def __eq__(self, other):
        """Compare with `other`, rejecting names that are not callbacks.

        Parameters
        ----------
        other : str
            The name to compare with.

        Returns
        -------
        bool
            ``True`` if `other` equals this name.

        Raises
        ------
        TypeError
            If `other` is not the name of a TrainingExtension callback.

        """
        # The list is rebuilt on every comparison from the current
        # contents of the TrainingExtension class namespace.
        callback_names = [key for key, value
                          in TrainingExtension.__dict__.items()
                          if getattr(value, '_is_callback', False)]
        if other not in callback_names:
            raise TypeError("{} is not a valid callback.".format(other))
        return str(self) == other
152
|
|
|
|
153
|
|
|
|
154
|
|
|
class Predicate(object):
    """Callable deciding whether a counter-based trigger fires.

    Parameters
    ----------
    condition : str
        The trigger name. A name ending in ``'epochs'`` selects the
        ``'epochs_done'`` status entry, any other name selects
        ``'iterations_done'``. A name starting with ``'every'`` fires
        whenever the entry is a multiple of `num`; any other name fires
        only when the entry equals `num`.
    num : int
        The number the selected status entry is checked against.

    """
    def __init__(self, condition, num):
        self.condition = condition
        self.num = num

    def __call__(self, log):
        """Evaluate the trigger against ``log.status``."""
        if self.condition.endswith('epochs'):
            counter = 'epochs_done'
        else:
            counter = 'iterations_done'
        entry = log.status[counter]
        if not self.condition.startswith('every'):
            return entry == self.num
        return entry % self.num == 0
168
|
|
|
|
169
|
|
|
|
170
|
|
|
def has_done_epochs(log):
    """Tell whether the ``'epochs_done'`` status entry of `log` is zero.

    Despite what the name suggests, the result is ``True`` when *no*
    epoch has been completed yet.

    """
    epochs_done = log.status['epochs_done']
    return epochs_done == 0
172
|
|
|
|
173
|
|
|
|
174
|
|
|
def always_true(log):
    """Predicate that holds unconditionally; `log` is ignored."""
    return True
176
|
|
|
|
177
|
|
|
|
178
|
|
|
@add_metaclass(ABCMeta)
class SimpleExtension(TrainingExtension):
    """A base class for simple extensions.

    A simple extension keeps all of its logic in the :meth:`do` method,
    which is run whenever one of the registered conditions is met.
    Conditions are registered through the trigger keywords accepted by
    the constructor and by :meth:`set_conditions`, or one at a time with
    :meth:`add_condition`. A condition may also carry additional
    arguments that are passed to :meth:`do` when it fires.

    Keyword arguments that are not trigger keywords are forwarded to
    the :class:`TrainingExtension` constructor.

    Parameters
    ----------
    before_training : bool
        If ``True``, :meth:`do` is invoked before training.
    before_first_epoch : bool
        If ``True``, :meth:`do` is invoked before the first epoch.
    before_epoch : bool
        If ``True``, :meth:`do` is invoked before every epoch.
    before_batch : bool
        If ``True``, :meth:`do` is invoked before every batch.
    on_resumption : bool, optional
        If ``True``, :meth:`do` is invoked when training is resumed.
    on_interrupt : bool, optional
        If ``True``, :meth:`do` is invoked when training is interrupted.
    after_epoch : bool
        If ``True``, :meth:`do` is invoked after every epoch.
    after_batch : bool
        If ``True``, :meth:`do` is invoked after every batch.
    after_training : bool
        If ``True``, :meth:`do` is invoked after training.
    after_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_epochs`
        epochs are done.
    every_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th epoch.
    after_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_batches`
        batches are processed.
    every_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th batch.

    """
    BOOLEAN_TRIGGERS = frozenset(["before_training", "before_first_epoch",
                                  "before_epoch", "before_batch",
                                  "on_resumption", "on_interrupt",
                                  "after_epoch", "after_batch",
                                  "after_training"])

    INTEGER_TRIGGERS = frozenset(["after_n_epochs", "after_n_batches",
                                  "every_n_epochs", "every_n_batches"])

    def __init__(self, **kwargs):
        self._conditions = []
        trigger_keywords = self.BOOLEAN_TRIGGERS | self.INTEGER_TRIGGERS
        # Split the keywords: triggers are handled here, everything else
        # belongs to the parent constructor.
        conditions = {key: value for key, value in kwargs.items()
                      if key in trigger_keywords}
        super_kwargs = {key: value for key, value in kwargs.items()
                        if key not in trigger_keywords}
        self.set_conditions(**conditions)
        super(SimpleExtension, self).__init__(**super_kwargs)

    def set_conditions(self, **kwargs):
        """Set the conditions for which this extension should be run.

        All previously registered conditions are discarded first.
        Keywords with a falsy value are skipped without being checked.

        Parameters
        ----------
        See the :class:`SimpleExtension` docstring for a list of
        possible parameters.

        Returns
        -------
        The extension object (allow chaining calls)

        Raises
        ------
        KeyError
            If a keyword with a truthy value is not a known trigger.

        """
        # Clear in place so that the list object itself is preserved.
        del self._conditions[:]
        custom_predicates = {'before_first_epoch': has_done_epochs}
        # Triggers missing from this table are dispatched on the
        # callback that carries the same name as the trigger.
        trigger_to_callback = {
            'before_first_epoch': 'before_epoch',
            'after_epoch': 'after_epoch',
            'after_batch': 'after_batch',
            'every_n_batches': 'after_batch',
            'every_n_epochs': 'after_epoch',
            'after_n_batches': 'after_batch',
            'after_n_epochs': 'after_epoch'
        }
        for key, value in kwargs.items():
            if not value:
                continue
            if key in self.BOOLEAN_TRIGGERS:
                predicate = custom_predicates.get(key, None)
            elif key in self.INTEGER_TRIGGERS:
                predicate = Predicate(key, value)
            else:
                raise KeyError("Invalid condition: {}".format(key))
            self.add_condition([trigger_to_callback.get(key, key)],
                               predicate=predicate)
        return self  # For chaining calls.

    def add_condition(self, callbacks_names, predicate=None, arguments=None):
        """Adds a condition under which a :meth:`do` is called.

        Parameters
        ----------
        callbacks_names : list of str
            The names of the callbacks in which :meth:`do` should be
            called. Must be a list or a tuple.
        predicate : function
            A predicate function taking the main loop's log as the
            single parameter and returning ``True`` when the method
            should be called and ``False`` when it should not.
            If ``None``, an always ``True`` predicate is used.
        arguments : iterable
            Additional arguments to be passed to :meth:`do`. They will
            be concatenated with the ones passed from the main loop
            (e.g. the batch in case of `after_batch` callback).

        Returns
        -------
        The extension object (allow chaining calls)

        Raises
        ------
        ValueError
            If `callbacks_names` is neither a list nor a tuple.

        """
        if not isinstance(callbacks_names, (list, tuple)):
            raise ValueError("callbacks_names must be list or tuple.")
        for name in callbacks_names:
            # A fresh empty list is created for every condition when no
            # arguments were given.
            self._conditions.append(
                (name, predicate or always_true, arguments or []))
        return self

    @abstractmethod
    def do(self, which_callback, *args):
        r"""Does the job of the training extension.

        Parameters
        ----------
        which_callback : str
            The name of the callback in the context of which :meth:`do` is
            run.
        \*args : tuple
            The arguments from the main loop concatenated with additional
            arguments from user.

        Notes
        -----
        Subclasses *must* accept additional positional arguments in their
        call signature for this method, even if they are unused.

        """
        pass

    def dispatch(self, callback_invoked, *from_main_loop):
        """Check conditions and call the :meth:`do` method.

        Every registered condition whose callback name matches
        `callback_invoked` and whose predicate accepts the main loop's
        log triggers one call to :meth:`do`. The arguments attached to
        the condition are appended to the ones from the main loop.

        .. todo::

            Add a check for a situation when several conditions are met
            at the same time and do something.

        """
        for name, predicate, extra in self._conditions:
            if name == callback_invoked:
                if predicate(self.main_loop.log):
                    self.do(callback_invoked,
                            *(from_main_loop + tuple(extra)))

    @staticmethod
    def parse_args(which_callback, args):
        """Separates :meth:`do` arguments coming from different sources.

        When a :meth:`do` method receives arguments from both the main
        loop (e.g. a batch) and the user, it often has to separate them.
        For the ``'after_batch'`` and ``'before_batch'`` callbacks the
        first argument is attributed to the main loop; for every other
        callback all arguments are attributed to the user.

        Parameters
        ----------
        which_callback : str
            The name of the callback.
        args : iterable
            The arguments.

        Returns
        -------
        from_main_loop : tuple
        from_user : tuple

        """
        args = tuple(args)
        batch_callback = (which_callback == 'after_batch' or
                          which_callback == 'before_batch')
        if not batch_callback:
            return (), args
        return (args[0],), args[1:]
374
|
|
|
|
375
|
|
|
|
376
|
|
|
class CompositeExtension(SimpleExtension):
    """An extension that manages several other extensions.

    Parameters
    ----------
    sub_extensions : iterable
        An iterable collection of sub-extensions to manage.
    run_before_children : bool, optional
        Whether the container extension's own logic should
        be dispatched before that of the sub-extensions.
        If ``False``, the containing extension is dispatched last.
        Defaults to ``True``.

    Notes
    -----
    The main use case for this class is bundling together groups
    of extensions that are most commonly used in tandem, configured
    so as to interact with one another. Encapsulating this pattern
    in a single extension reduces boilerplate.

    Sub-extensions are dispatched in the order specified in
    ``sub_extensions``, on whatever triggers they are individually
    configured to respect.

    Sub-extensions may be run on different triggers than the containing
    extension; the trigger keywords passed to the constructor
    for this class only affect the outer extension's logic, and
    sub-extensions should be configured independently (possibly in
    a constructor for a subclass of :class:`CompositeExtension`).

    """
    def __init__(self, sub_extensions, run_before_children=True, **kwargs):
        self.sub_extensions = sub_extensions
        self.run_before_children = run_before_children
        super(CompositeExtension, self).__init__(**kwargs)

    def dispatch(self, callback_invoked, *from_main_loop):
        """Dispatch the callback to this extension and to its children.

        The container's own conditions are checked either before or
        after the sub-extensions, depending on ``run_before_children``.

        """
        own_dispatch = super(CompositeExtension, self).dispatch

        if self.run_before_children:
            own_dispatch(callback_invoked, *from_main_loop)

        for child in self.sub_extensions:
            child.dispatch(callback_invoked, *from_main_loop)

        if not self.run_before_children:
            own_dispatch(callback_invoked, *from_main_loop)

    @property
    def main_loop(self):
        """The main loop, shared with every sub-extension on assignment."""
        return super(CompositeExtension, self).main_loop

    @main_loop.setter
    def main_loop(self, value):
        self._main_loop = value
        # Propagate the main loop so the children can be dispatched too.
        for child in self.sub_extensions:
            child.main_loop = value

    def do(self, which_callback, *args):
        """Do nothing; subclasses may override to add container logic."""
        pass
437
|
|
|
|
438
|
|
|
|
439
|
|
|
class FinishAfter(SimpleExtension):
    """Finishes the training process when triggered.

    The request is made by setting the ``'training_finish_requested'``
    record of the current log row to ``True``.

    """
    def __init__(self, **kwargs):
        super(FinishAfter, self).__init__(**kwargs)

    def do(self, which_callback, *args):
        row = self.main_loop.log.current_row
        row['training_finish_requested'] = True
446
|
|
|
|
447
|
|
|
|
448
|
|
|
class Printing(SimpleExtension):
    """Prints log messages to the screen.

    Unless overridden by the caller, the extension is triggered before
    the first epoch, on resumption, after training, after every epoch
    and on interrupt.

    """
    def __init__(self, **kwargs):
        default_triggers = ("before_first_epoch", "on_resumption",
                            "after_training", "after_epoch", "on_interrupt")
        for trigger in default_triggers:
            kwargs.setdefault(trigger, True)
        super(Printing, self).__init__(**kwargs)

    def _print_attributes(self, attribute_tuples):
        """Print the items of a mapping sorted by key.

        Entries whose key starts with an underscore are left out.

        """
        ordered = sorted(attribute_tuples.items(), key=first)
        for attr, value in ordered:
            if attr.startswith("_"):
                continue
            print("\t", "{}:".format(attr), value)

    def do(self, which_callback, *args):
        """Print a banner for the callback, then the status and records.

        The training status is not printed on interrupt.

        """
        log = self.main_loop.log
        print_status = True
        rule = "-" * 79
        # Banners are matched with ``==`` in this order; the callback
        # name is deliberately never hashed.
        banners = (
            ("on_resumption", "TRAINING HAS BEEN RESUMED"),
            ("after_training", "TRAINING HAS BEEN FINISHED:"),
            ("after_epoch", "AFTER ANOTHER EPOCH"),
            ("on_interrupt", "TRAINING HAS BEEN INTERRUPTED"),
        )

        print()
        print(rule)
        if which_callback == "before_epoch" and log.status['epochs_done'] == 0:
            print("BEFORE FIRST EPOCH")
        else:
            for name, banner in banners:
                if which_callback == name:
                    print(banner)
                    if name == "on_interrupt":
                        print_status = False
                    break
        print(rule)
        if print_status:
            print("Training status:")
            self._print_attributes(log.status)
            print("Log records from the iteration {}:".format(
                log.status['iterations_done']))
            self._print_attributes(log.current_row)
        print()
488
|
|
|
|
489
|
|
|
|
490
|
|
|
class ProgressBar(TrainingExtension):
    """Display a progress bar during training.

    This extension tries to infer the number of iterations per epoch
    by querying the `num_batches`, `num_examples` and `batch_size`
    attributes from the :class:`IterationScheme`. When this information is
    not available it will display a simplified progress bar that does not
    include the estimated time until the end of this epoch.

    Notes
    -----
    This extension should be run before other extensions that print to
    the screen at the end or at the beginning of the epoch (e.g. the
    :class:`Printing` extension). Placing ProgressBar before these
    extension will ensure you won't get intermingled output on your
    terminal.

    """
    def __init__(self, **kwargs):
        super(ProgressBar, self).__init__(**kwargs)
        self.bar = None
        self.iter_count = 0

    def __getstate__(self):
        # Ensure we won't pickle the actual progress bar.
        # (It might contain unpicklable file handles)
        state = dict(self.__dict__)
        # ``pop`` rather than ``del``: tolerate instances that, for
        # whatever reason, carry no ``bar`` entry.
        state.pop('bar', None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.bar = None

    def get_iter_per_epoch(self):
        """Try to infer the number of iterations per epoch.

        Returns
        -------
        int or None
            The `num_batches` attribute of the iteration scheme if it
            has one; otherwise `num_examples` divided by `batch_size`,
            rounded up; ``None`` when the number cannot be determined.

        """
        iter_scheme = self.main_loop.data_stream.iteration_scheme
        if hasattr(iter_scheme, 'num_batches'):
            return iter_scheme.num_batches
        elif (hasattr(iter_scheme, 'num_examples') and
                hasattr(iter_scheme, 'batch_size')):
            num_examples = iter_scheme.num_examples
            batch_size = iter_scheme.batch_size
            # The attributes may exist without carrying usable values;
            # treat that as "unknown" instead of failing on the division.
            if num_examples is None or not batch_size:
                return None
            # Round up: a trailing partial batch is still an iteration.
            # Rounding down would let the counter exceed the maximum of
            # the bar on the last batch of the epoch.
            return -(-num_examples // batch_size)
        return None

    def create_bar(self):
        """Create a new progress bar.

        Calls `self.get_iter_per_epoch()`, selects an appropriate
        set of widgets and creates a ProgressBar.

        """
        iter_per_epoch = self.get_iter_per_epoch()
        epochs_done = self.main_loop.log.status['epochs_done']

        if iter_per_epoch is None:
            # Unknown length: no percentage and no ETA can be shown.
            widgets = ["Epoch {}, step ".format(epochs_done),
                       progressbar.Counter(), ' ',
                       progressbar.BouncingBar(), ' ',
                       progressbar.Timer()]
            iter_per_epoch = progressbar.UnknownLength
        else:
            widgets = ["Epoch {}, step ".format(epochs_done),
                       progressbar.Counter(),
                       ' (', progressbar.Percentage(), ') ',
                       progressbar.Bar(), ' ',
                       progressbar.Timer(), ' ', progressbar.ETA()]

        return progressbar.ProgressBar(widgets=widgets,
                                       max_value=iter_per_epoch)

    def before_epoch(self):
        self.iter_count = 0

    def after_epoch(self):
        if self.bar is None:
            return

        self.bar.finish()
        self.bar = None

    def before_batch(self, batch):
        # The bar is created lazily, on the first batch of an epoch.
        if self.bar is None:
            self.bar = self.create_bar()
            self.bar.start()

        self.iter_count += 1
        self.bar.update(self.iter_count)
577
|
|
|
|
578
|
|
|
|
579
|
|
|
class Timing(SimpleExtension):
    """Add timing information to the log.

    This adds data about the time spent in the algorithm's
    :meth:`~.Algorithm.process_batch` method as well as the time spent
    reading data per batch or epoch. It also reports the time spent
    initializing the algorithm.

    Parameters
    ----------
    prefix : str
        Prefix to be added to the log record. Defaults to the empty string.

    Notes
    -----
    Add this extension *before* the :class:`Printing` extension.

    Created with callbacks like ``every_n_batches`` this extension
    averages the time.

    This extension does *not* enable full profiling information. To see a
    full profile of the main loop at the end of training, use the
    ``profile`` configuration (e.g. by setting ``BLOCKS_PROFILE=true``).

    """
    def __init__(self, prefix="", **kwargs):
        kwargs.setdefault('before_first_epoch', True)
        kwargs.setdefault('after_epoch', True)
        super(Timing, self).__init__(**kwargs)

        def init_dict():
            # One independent counter table per level.
            return dict((level, dict.fromkeys(('train', 'read_data'), 0))
                        for level in ('batch', 'epoch'))

        self.current = init_dict()
        self.previous = init_dict()
        self.current_index = init_dict()
        self.previous_index = init_dict()
        # A non-empty prefix is separated from the record name by "_".
        self.prefix = prefix + '_' if prefix else prefix

    def do(self, which_callback, *args):
        """Write the timing records for `which_callback` to the log.

        Raises
        ------
        ValueError
            If `which_callback` is not ``'before_epoch'``,
            ``'after_batch'`` or ``'after_epoch'``.

        """
        current_row = self.main_loop.log.current_row
        profile = self.main_loop.profile.total

        if which_callback == 'before_epoch':
            current_row['time_initialization'] = profile[('initialization',)]
            return

        if which_callback == 'after_batch':
            level, counter = 'batch', 'iterations_done'
        elif which_callback == 'after_epoch':
            level, counter = 'epoch', 'epochs_done'
        else:
            raise ValueError('wrong callback type `{}`'.format(which_callback))

        for action in ('train', 'read_data'):
            # Shift the counter seen at the previous call, then record
            # the one seen now.
            self.previous_index[level][action] = (
                self.current_index[level][action])
            self.current_index[level][action] = (
                self.main_loop.log.status[counter])
            current_index = self.current_index[level][action]
            previous_index = self.previous_index[level][action]
            if current_index == previous_index:
                logger.debug('Timing extension was called twice this %s, '
                             'log was not updated.', level)
                # Nothing to report for this level
                continue

            # The same profile entry is read for both levels.
            elapsed_before = self.current[level][action]
            self.previous[level][action] = elapsed_before
            elapsed_now = profile['training', 'epoch', action]
            self.current[level][action] = elapsed_now

            # Time per step, averaged over the steps since the last call.
            this_key = (self.prefix + 'time_{}_this_{}').format(action, level)
            current_row[this_key] = (
                (elapsed_now - elapsed_before) /
                (current_index - previous_index))
            total_key = (self.prefix + 'time_{}_total').format(action)
            current_row[total_key] = elapsed_now
659
|
|
|
|
660
|
|
|
|
661
|
|
|
class Timestamp(SimpleExtension):
    """Adds a human readable (ISO 8601) timestamp to the log.

    Unless overridden by the caller, the extension is triggered after
    every epoch.

    Parameters
    ----------
    log_record : str, optional
        The record name to use. Defaults to 'timestamp'.
    separator : str, optional
        Separator between the date and time. ISO 8601 specifies 'T'.
        Here, we default to ' ' (blank space) for human readability.

    """
    DEFAULT_LOG_RECORD = 'timestamp'

    def __init__(self, log_record=DEFAULT_LOG_RECORD, separator=' ',
                 **kwargs):
        self.log_record = log_record
        self.separator = separator
        kwargs.setdefault('after_epoch', True)
        super(Timestamp, self).__init__(**kwargs)

    def do(self, *args):
        """Store the current timestamp in the current log row."""
        self.main_loop.log.current_row[self.log_record] = self.get_timestamp()

    def get_timestamp(self):
        """Return the current local time as an ISO 8601 string.

        Returns
        -------
        str
            The date and time joined by ``self.separator``.

        """
        # Separated into a method to override for ease of testing.
        # ``isoformat`` must be called on a datetime *instance*; calling
        # it on the class would pass the separator in place of ``self``
        # and raise a TypeError.
        return datetime.datetime.now().isoformat(self.separator)
688
|
|
|
|