|
1
|
|
|
from __future__ import print_function |
|
2
|
|
|
|
|
3
|
|
|
import datetime |
|
4
|
|
|
import logging |
|
5
|
|
|
from abc import ABCMeta, abstractmethod |
|
6
|
|
|
|
|
7
|
|
|
import progressbar |
|
8
|
|
|
from six import add_metaclass |
|
9
|
|
|
from toolz import first |
|
10
|
|
|
|
|
11
|
|
|
# Module-wide logger; Timing uses it to report repeated invocations.
logger = logging.getLogger(name=__name__)
|
12
|
|
|
|
|
13
|
|
|
|
|
14
|
|
|
def callback(func):
    """Mark ``func`` as a training-extension callback.

    The function object is tagged in place with a ``_is_callback``
    attribute set to ``True`` and returned unchanged, so this works as a
    decorator. :class:`CallbackName` looks for that tag to decide which
    names are valid callbacks.

    """
    setattr(func, '_is_callback', True)
    return func
|
17
|
|
|
|
|
18
|
|
|
|
|
19
|
|
|
class TrainingExtension(object):
    """The base class for training extensions.

    An extension bundles a group of callbacks that share state and are
    triggered at specific points of the training procedure. Such
    callbacks usually plug extra functionality into training, for
    instance validating on auxiliary datasets or stopping early.

    Parameters
    ----------
    name : str, optional
        The name of the extension. Names help tell apart several
        extensions of the same type attached to one main loop. When the
        name is omitted (or empty), the class name is used.

    Attributes
    ----------
    main_loop : :class:`.MainLoop`
        The main loop to which the extension belongs. Reading it before
        it has been assigned raises :class:`ValueError`.
    name : str
        The name of the extension.

    """
    def __init__(self, name=None):
        # A missing or empty name falls back to the class name.
        self.name = name or self.__class__.__name__

    @property
    def main_loop(self):
        if hasattr(self, '_main_loop'):
            return self._main_loop
        raise ValueError("main loop must be assigned to extension first")

    @main_loop.setter
    def main_loop(self, value):
        self._main_loop = value

    def dispatch(self, callback_name, *args):
        """Invoke the callback called ``callback_name`` with ``args``.

        Subclasses of :class:`TrainingExtension` can override this hook
        to intercept callback invocations, e.g. to suppress them unless
        some condition holds. The base implementation looks the callback
        up by name and calls it.

        """
        handler = getattr(self, str(callback_name))
        handler(*args)

    @callback
    def on_resumption(self):
        """Called right after training has been resumed."""

    @callback
    def on_error(self, exception):
        """Called when an error occurs.

        Parameters
        ----------
        exception : object
            The exception raised while the main loop was running.

        """

    @callback
    def before_training(self):
        """Called before training starts."""

    @callback
    def before_epoch(self):
        """Called before an epoch starts."""

    @callback
    def before_batch(self, batch):
        """Called before a batch is processed.

        Parameters
        ----------
        batch : object
            The data batch about to be processed.

        """

    @callback
    def after_batch(self, batch):
        """Called after a batch has been processed.

        Parameters
        ----------
        batch : object
            The data batch that was just processed.

        """

    @callback
    def after_epoch(self):
        """Called after an epoch has finished."""

    @callback
    def after_training(self):
        """Called after training has finished."""

    @callback
    def on_interrupt(self):
        """Called when training is interrupted."""
|
134
|
|
|
|
|
135
|
|
|
|
|
136
|
|
|
class CallbackName(str):
    """A name of a TrainingExtension callback.

    Comparing an instance with a string validates the *other* operand.
    The comparison raises unless that string names one of the methods of
    :class:`TrainingExtension` tagged by the ``@callback`` decorator. This
    catches typos such as ``which_callback == 'after_epohc'``, which a
    plain string comparison would silently evaluate to ``False``.

    Raises
    ------
    :class:`TypeError` on comparison (``==`` or ``!=``) with a string
    which is not a name of a TrainingExtension callback.

    Notes
    -----
    Overriding ``__eq__`` alone makes a class unhashable on Python 3, so
    ``__hash__`` is restored from :class:`str`. Instances hash like the
    plain string they wrap.

    """
    def __eq__(self, other):
        callback_names = [key for key, value
                          in TrainingExtension.__dict__.items()
                          if getattr(value, '_is_callback', False)]
        if other not in callback_names:
            raise TypeError("{} is not a valid callback.".format(other))
        return str(self) == other

    def __ne__(self, other):
        # str defines its own __ne__, which would otherwise bypass the
        # validation performed in __eq__.
        return not self.__eq__(other)

    # Defining __eq__ sets __hash__ to None on Python 3; restore it so
    # instances can still be used in sets and as dictionary keys.
    __hash__ = str.__hash__
|
152
|
|
|
|
|
153
|
|
|
|
|
154
|
|
|
class Predicate(object):
    """Predicate over the main loop's log for integer triggers.

    Parameters
    ----------
    condition : str
        The trigger name, e.g. ``'every_n_epochs'`` or
        ``'after_n_batches'``. A name ending in ``'epochs'`` is checked
        against ``log.status['epochs_done']``; any other name against
        ``log.status['iterations_done']``. A name starting with
        ``'every'`` fires on every multiple of `num`; any other name
        fires only when the counter equals `num` exactly.
    num : int
        The epoch or iteration count the trigger refers to.

    """
    def __init__(self, condition, num):
        self.condition = condition
        self.num = num

    def __call__(self, log):
        if self.condition.endswith('epochs'):
            status_key = 'epochs_done'
        else:
            status_key = 'iterations_done'
        count = log.status[status_key]
        if not self.condition.startswith('every'):
            return count == self.num
        return count % self.num == 0
|
168
|
|
|
|
|
169
|
|
|
|
|
170
|
|
|
def has_done_epochs(log):
    """Return whether no epoch has been completed yet.

    Despite its name, this is ``True`` only while
    ``log.status['epochs_done']`` equals zero. It backs the
    ``before_first_epoch`` trigger of :class:`SimpleExtension`.

    """
    epochs_completed = log.status['epochs_done']
    return epochs_completed == 0
|
172
|
|
|
|
|
173
|
|
|
|
|
174
|
|
|
def always_true(log):
    """Predicate that accepts any log; its argument is ignored."""
    del log  # unused; kept so the signature matches other predicates
    return True
|
176
|
|
|
|
|
177
|
|
|
|
|
178
|
|
|
@add_metaclass(ABCMeta)
class SimpleExtension(TrainingExtension):
    """A base class for simple extensions.

    All logic of simple extensions is concentrated in the method
    :meth:`do`. This method is called when certain conditions are
    fulfilled. The user can manage the conditions by calling the
    `add_condition` method and by passing arguments to the constructor. In
    addition to specifying when :meth:`do` is called, it is possible to
    specify additional arguments passed to :meth:`do` under different
    conditions.

    Parameters
    ----------
    before_training : bool
        If ``True``, :meth:`do` is invoked before training.
    before_first_epoch : bool
        If ``True``, :meth:`do` is invoked before the first epoch.
    before_epoch : bool
        If ``True``, :meth:`do` is invoked before every epoch.
    before_batch : bool, optional
        If ``True``, :meth:`do` is invoked before every batch.
    on_resumption : bool, optional
        If ``True``, :meth:`do` is invoked when training is resumed.
    on_interrupt : bool, optional
        If ``True``, :meth:`do` is invoked when training is interrupted.
    on_error : bool, optional
        If ``True``, :meth:`do` is invoked when an error occurs.
    after_epoch : bool
        If ``True``, :meth:`do` is invoked after every epoch.
    after_batch : bool
        If ``True``, :meth:`do` is invoked after every batch.
    after_training : bool
        If ``True``, :meth:`do` is invoked after training.
    after_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_epochs`
        epochs are done.
    every_n_epochs : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th epoch.
    after_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked when `after_n_batches`
        batches are processed.
    every_n_batches : int, optional
        If not ``None``, :meth:`do` is invoked after every n-th batch.

    Any other keyword argument is forwarded to :class:`TrainingExtension`
    (e.g. ``name``).

    """
    BOOLEAN_TRIGGERS = frozenset(["before_training", "before_first_epoch",
                                  "before_epoch", "before_batch",
                                  "on_resumption", "on_interrupt", "on_error",
                                  "after_epoch", "after_batch",
                                  "after_training"])

    INTEGER_TRIGGERS = frozenset(["after_n_epochs", "after_n_batches",
                                  "every_n_epochs", "every_n_batches"])

    def __init__(self, **kwargs):
        self._conditions = []
        super_kwargs = {}
        trigger_keywords = self.BOOLEAN_TRIGGERS | self.INTEGER_TRIGGERS
        conditions = {}
        for key, value in kwargs.items():
            if key in trigger_keywords:
                conditions[key] = value
            else:
                super_kwargs[key] = value
        self.set_conditions(**conditions)
        super(SimpleExtension, self).__init__(**super_kwargs)

    def set_conditions(self, **kwargs):
        """Set the conditions for which this extension should be run.

        Any previously registered conditions are discarded, including
        those added with :meth:`add_condition`.

        Parameters
        ----------
        See the :class:`SimpleExtension` docstring for a list of
        possible parameters. Triggers given a falsy value (``False``,
        ``0``, ``None``) are ignored.

        Returns
        -------
        The extension object (allow chaining calls).

        Raises
        ------
        KeyError
            If any keyword is not a known trigger name, whatever its value.
            Nothing is modified in that case.

        """
        # Validate every key before touching state, so that a misspelt
        # trigger is reported even when it is switched off, and an error
        # does not leave the extension with a partial set of conditions.
        for key in kwargs:
            if (key not in self.BOOLEAN_TRIGGERS and
                    key not in self.INTEGER_TRIGGERS):
                raise KeyError("Invalid condition: {}".format(key))

        self._conditions[:] = []
        predicates = {'before_first_epoch': has_done_epochs}
        # Trigger keyword -> callback it attaches to. Triggers missing
        # from this mapping attach to the callback of the same name.
        conditions = {
            'before_first_epoch': 'before_epoch',
            'after_epoch': 'after_epoch',
            'after_batch': 'after_batch',
            'every_n_batches': 'after_batch',
            'every_n_epochs': 'after_epoch',
            'after_n_batches': 'after_batch',
            'after_n_epochs': 'after_epoch'
        }
        for key, value in kwargs.items():
            if not value:
                continue
            if key in self.BOOLEAN_TRIGGERS:
                predicate = predicates.get(key, None)
            else:
                predicate = Predicate(key, value)
            self.add_condition([conditions.get(key, key)],
                               predicate=predicate)
        return self  # For chaining calls.

    def add_condition(self, callbacks_names, predicate=None, arguments=None):
        """Adds a condition under which a :meth:`do` is called.

        Parameters
        ----------
        callbacks_names : list of str
            The names of the callbacks in which :meth:`do` should be
            called.
        predicate : function
            A predicate function taking the main loop's log as the
            single parameter and returning ``True`` when the method
            should be called and ``False`` when it should not. If
            ``None``, an always ``True`` predicate is used.
        arguments : iterable
            Additional arguments to be passed to :meth:`do`. They will
            be concatenated with the ones passed from the main loop
            (e.g. the batch in case of `after_batch` callback).

        Returns
        -------
        The extension object (allow chaining calls)

        Raises
        ------
        ValueError
            If `callbacks_names` is not a list or a tuple.

        """
        if not isinstance(callbacks_names, (list, tuple)):
            raise ValueError("callbacks_names must be list or tuple.")
        if not predicate:
            predicate = always_true
        # Materialize once. Otherwise a generator would be exhausted by the
        # first dispatch and shared between all the listed callbacks.
        arguments = tuple(arguments) if arguments else ()
        for _callback_name in callbacks_names:
            self._conditions.append((_callback_name, predicate, arguments))
        return self

    @abstractmethod
    def do(self, which_callback, *args):
        r"""Does the job of the training extension.

        Parameters
        ----------
        which_callback : str
            The name of the callback in the context of which :meth:`do` is
            run.
        \*args : tuple
            The arguments from the main loop concatenated with additional
            arguments from user.

        Notes
        -----
        Subclasses *must* accept additional positional arguments in their
        call signature for this method, even if they are unused.

        """
        pass

    def dispatch(self, callback_invoked, *from_main_loop):
        """Check conditions and call the :meth:`do` method.

        Also adds additional arguments if specified for a condition.
        :meth:`do` is called once per matching condition, so several
        satisfied conditions for the same callback call it several times.

        .. todo::

           Add a check for a situation when several conditions are met
           at the same time and do something.

        """
        for callback_name, predicate, arguments in self._conditions:
            # The predicate is evaluated only for the matching callback.
            if (callback_name == callback_invoked and
                    predicate(self.main_loop.log)):
                self.do(callback_invoked,
                        *(from_main_loop + tuple(arguments)))

    @staticmethod
    def parse_args(which_callback, args):
        """Separates :meth:`do` arguments coming from different sources.

        When a :meth:`do` method receives arguments from both the main
        loop (e.g. a batch) and the user, it often has to separate them.
        This method is the right tool to use.

        Parameters
        ----------
        which_callback : str
            The name of the callback.
        args : iterable
            The arguments.

        Returns
        -------
        from_main_loop : tuple
            The batch for ``before_batch``/``after_batch``, else empty.
        from_user : tuple
            The remaining arguments.

        """
        args = tuple(args)
        if (which_callback == 'after_batch' or
                which_callback == 'before_batch'):
            return (args[0],), args[1:]
        return (), args
|
374
|
|
|
|
|
375
|
|
|
|
|
376
|
|
|
class CompositeExtension(SimpleExtension):
    """An extension that manages several other extensions.

    Parameters
    ----------
    sub_extensions : iterable
        A collection of sub-extensions to manage. It is traversed on
        every dispatch and whenever the main loop is assigned.
    run_before_children : bool, optional
        If ``True`` (the default), this extension's own conditions are
        dispatched before those of the sub-extensions. If ``False``, they
        are dispatched after them.

    Notes
    -----
    This class mainly serves to bundle groups of extensions that are
    usually used together and configured to cooperate. Wrapping the
    group in one extension cuts down on boilerplate.

    On each callback, every sub-extension is dispatched in the order of
    ``sub_extensions``. Each one decides from its own conditions whether
    it actually runs.

    The trigger keywords given to this constructor configure only the
    outer extension. Sub-extensions must be configured separately,
    possibly in the constructor of a :class:`CompositeExtension`
    subclass.

    """
    def __init__(self, sub_extensions, run_before_children=True, **kwargs):
        self.sub_extensions = sub_extensions
        self.run_before_children = run_before_children
        super(CompositeExtension, self).__init__(**kwargs)

    def dispatch(self, callback_invoked, *from_main_loop):
        own_dispatch = super(CompositeExtension, self).dispatch
        if self.run_before_children:
            own_dispatch(callback_invoked, *from_main_loop)

        for child in self.sub_extensions:
            child.dispatch(callback_invoked, *from_main_loop)

        if not self.run_before_children:
            own_dispatch(callback_invoked, *from_main_loop)

    @property
    def main_loop(self):
        return super(CompositeExtension, self).main_loop

    @main_loop.setter
    def main_loop(self, value):
        # Assigning the main loop also propagates it to every child.
        self._main_loop = value
        for child in self.sub_extensions:
            child.main_loop = value

    def do(self, which_callback, *args):
        """No-op: the container has no logic of its own by default."""
|
437
|
|
|
|
|
438
|
|
|
|
|
439
|
|
|
class FinishAfter(SimpleExtension):
    """Finishes the training process when triggered.

    When :meth:`do` runs, it sets ``training_finish_requested`` to
    ``True`` in the current row of the main loop's log. All keyword
    arguments go to :class:`SimpleExtension` and select the triggers.

    """
    def __init__(self, **kwargs):
        super(FinishAfter, self).__init__(**kwargs)

    def do(self, which_callback, *args):
        row = self.main_loop.log.current_row
        row['training_finish_requested'] = True
|
446
|
|
|
|
|
447
|
|
|
|
|
448
|
|
|
class Printing(SimpleExtension):
    """Prints log messages to the screen.

    By default it triggers before the first epoch, after every epoch,
    after training, on resumption and on interruption. Each of these can
    be switched off by passing ``False`` for the corresponding keyword
    (see :class:`SimpleExtension`). On interruption only the header is
    printed, without the training status or log records.

    """
    def __init__(self, **kwargs):
        kwargs.setdefault("before_first_epoch", True)
        kwargs.setdefault("on_resumption", True)
        kwargs.setdefault("after_training", True)
        kwargs.setdefault("after_epoch", True)
        kwargs.setdefault("on_interrupt", True)
        super(Printing, self).__init__(**kwargs)

    def _print_attributes(self, attribute_tuples):
        """Print the entries of a mapping sorted by key.

        Entries whose key starts with an underscore are skipped.

        """
        for attr, value in sorted(attribute_tuples.items(), key=first):
            if not attr.startswith("_"):
                print("\t", "{}:".format(attr), value)

    def do(self, which_callback, *args):
        log = self.main_loop.log
        print_status = True
        separator = "-" * 79

        print()
        print(separator)
        # Kept as an equality chain: `which_callback` may be a
        # CallbackName, whose comparisons validate the compared names.
        if which_callback == "before_epoch" and log.status['epochs_done'] == 0:
            print("BEFORE FIRST EPOCH")
        elif which_callback == "on_resumption":
            print("TRAINING HAS BEEN RESUMED")
        elif which_callback == "after_training":
            print("TRAINING HAS BEEN FINISHED:")
        elif which_callback == "after_epoch":
            print("AFTER ANOTHER EPOCH")
        elif which_callback == "on_interrupt":
            print("TRAINING HAS BEEN INTERRUPTED")
            print_status = False
        print(separator)
        if print_status:
            print("Training status:")
            self._print_attributes(log.status)
            print("Log records from the iteration {}:".format(
                log.status['iterations_done']))
            self._print_attributes(log.current_row)
        print()
|
488
|
|
|
|
|
489
|
|
|
|
|
490
|
|
|
class ProgressBar(TrainingExtension):
    """Display a progress bar during training.

    This extension tries to infer the number of iterations per epoch from
    the `indices`, `num_examples` and `batch_size` attributes of the
    :class:`IterationScheme`. When this information is not available, it
    shows a simplified progress bar without the estimated time until the
    end of the epoch.

    Notes
    -----
    Run this extension before other extensions that print to the screen
    at the end or the beginning of the epoch (e.g. the
    :class:`Printing` extension). Placing ProgressBar before them keeps
    their output from being interleaved with the bar on your terminal.

    """
    def __init__(self, **kwargs):
        super(ProgressBar, self).__init__(**kwargs)
        self.bar = None
        self.iter_count = 0

    def __getstate__(self):
        # Ensure we won't pickle the actual progress bar.
        # (It might contain unpicklable file handles)
        state = dict(self.__dict__)
        state.pop('bar', None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.bar = None

    def get_iter_per_epoch(self):
        """Try to infer the number of iterations per epoch.

        Returns
        -------
        int or None
            The inferred number of iterations. ``None`` when the
            iteration scheme exposes neither `indices` nor a non-``None``
            `num_examples` together with `batch_size`.

        Notes
        -----
        Batch counts are rounded *up*, because a final partial batch is
        still an iteration. Rounding down would let the iteration counter
        exceed the bar's ``max_value``.

        """
        def ceil_div(numerator, denominator):
            return -(-numerator // denominator)

        iter_scheme = self.main_loop.data_stream.iteration_scheme
        has_indices = hasattr(iter_scheme, 'indices')
        has_batch_size = hasattr(iter_scheme, 'batch_size')
        if has_indices and not has_batch_size:
            # schemes that extend IndexScheme
            return len(iter_scheme.indices)
        elif has_indices and has_batch_size:
            # schemes that extend BatchScheme
            return ceil_div(len(iter_scheme.indices), iter_scheme.batch_size)
        elif (getattr(iter_scheme, 'num_examples', None) is not None and
                has_batch_size):
            # ConstantScheme; num_examples may be None (e.g. an unbounded
            # scheme), in which case the length is unknown.
            return ceil_div(iter_scheme.num_examples, iter_scheme.batch_size)
        return None

    def create_bar(self):
        """Create a new progress bar.

        Calls `self.get_iter_per_epoch()`, selects an appropriate
        set of widgets and creates a ProgressBar.

        """
        iter_per_epoch = self.get_iter_per_epoch()
        epochs_done = self.main_loop.log.status['epochs_done']

        if iter_per_epoch is None:
            widgets = ["Epoch {}, step ".format(epochs_done),
                       progressbar.Counter(), ' ',
                       progressbar.BouncingBar(), ' ',
                       progressbar.Timer()]
            iter_per_epoch = progressbar.UnknownLength
        else:
            widgets = ["Epoch {}, step ".format(epochs_done),
                       progressbar.Counter(),
                       ' (', progressbar.Percentage(), ') ',
                       progressbar.Bar(), ' ',
                       progressbar.Timer(), ' ', progressbar.ETA()]

        return progressbar.ProgressBar(widgets=widgets,
                                       max_value=iter_per_epoch)

    def before_epoch(self):
        self.iter_count = 0

    def after_epoch(self):
        if self.bar is None:
            return

        self.bar.finish()
        self.bar = None

    def before_batch(self, batch):
        # The bar is created lazily on the first batch of each epoch.
        if self.bar is None:
            self.bar = self.create_bar()
            self.bar.start()

        self.iter_count += 1
        self.bar.update(self.iter_count)
|
584
|
|
|
|
|
585
|
|
|
|
|
586
|
|
|
class Timing(SimpleExtension):
    """Add timing information to the log.

    Records the time spent in the algorithm's
    :meth:`~.Algorithm.process_batch` method and the time spent reading
    data, per batch or per epoch. It also reports how long the algorithm
    took to initialize.

    Parameters
    ----------
    prefix : str
        Prefix to be added to the log record. Defaults to the empty string.

    Notes
    -----
    Add this extension *before* the :class:`Printing` extension.

    When created with callbacks such as ``every_n_batches``, this
    extension reports the time averaged over the batches since its last
    report.

    This extension does *not* enable full profiling information. To see a
    full profile of the main loop at the end of training, use the
    ``profile`` configuration (e.g. by setting ``BLOCKS_PROFILE=true``).

    """
    def __init__(self, prefix="", **kwargs):
        for trigger in ('before_first_epoch', 'after_epoch'):
            kwargs.setdefault(trigger, True)
        super(Timing, self).__init__(**kwargs)

        def zeroed():
            # A fresh, independent counter table for each level.
            table = {}
            for level in ('batch', 'epoch'):
                table[level] = {'train': 0, 'read_data': 0}
            return table

        # Cumulative profile totals and status counters at the latest
        # report (current) and the one before it (previous).
        self.current, self.previous = zeroed(), zeroed()
        self.current_index, self.previous_index = zeroed(), zeroed()
        self.prefix = prefix + '_' if prefix else prefix

    def do(self, which_callback, *args):
        current_row = self.main_loop.log.current_row
        profile = self.main_loop.profile.total

        if which_callback == 'before_epoch':
            current_row['time_initialization'] = profile[('initialization',)]
            return
        if which_callback == 'after_batch':
            level, counter = 'batch', 'iterations_done'
        elif which_callback == 'after_epoch':
            level, counter = 'epoch', 'epochs_done'
        else:
            raise ValueError('wrong callback type `{}`'.format(which_callback))

        done = self.main_loop.log.status[counter]
        for action in ('train', 'read_data'):
            prev_idx = self.previous_index[level]
            cur_idx = self.current_index[level]
            prev_idx[action] = cur_idx[action]
            cur_idx[action] = done
            since = prev_idx[action]
            if done == since:
                logger.debug('Timing extension was called twice this %s, '
                             'log was not updated.', level)
                # Nothing to report for this level
                continue

            prev_times = self.previous[level]
            cur_times = self.current[level]
            prev_times[action] = cur_times[action]
            cur_times[action] = profile['training', 'epoch', action]

            # Average time per unit (batch/epoch) since the last report.
            per_unit_key = (self.prefix + 'time_{}_this_{}').format(
                action, level)
            current_row[per_unit_key] = (
                (cur_times[action] - prev_times[action]) / (done - since))
            total_key = (self.prefix + 'time_{}_total').format(action)
            current_row[total_key] = cur_times[action]
|
666
|
|
|
|
|
667
|
|
|
|
|
668
|
|
|
class Timestamp(SimpleExtension):
    """Adds a human readable (ISO 8601) timestamp to the log.

    Parameters
    ----------
    log_record : str, optional
        The record name to use. Defaults to 'timestamp'.
    separator : str, optional
        Separator between the date and time. ISO 8601 specifies 'T';
        the default here is ' ' (a blank space) for readability.

    Notes
    -----
    By default it triggers before training starts, after every epoch,
    after training finishes, when an error occurs, and when training is
    interrupted or resumed. These are all useful moments to have a
    timestamp for. Any of them can be disabled by passing `False` as the
    corresponding keyword argument; see :class:`SimpleExtension`.

    """
    DEFAULT_LOG_RECORD = 'timestamp'

    def __init__(self, log_record=DEFAULT_LOG_RECORD, separator=' ',
                 **kwargs):
        self.log_record = log_record
        self.separator = separator
        for trigger in ('before_training', 'after_epoch', 'on_error',
                        'on_interrupt', 'on_resumption', 'after_training'):
            kwargs.setdefault(trigger, True)
        super(Timestamp, self).__init__(**kwargs)

    def do(self, *args):
        stamp = self.get_timestamp()
        self.main_loop.log.current_row[self.log_record] = stamp

    def get_timestamp(self):
        """Return the current local time formatted with ISO 8601.

        Kept as a separate method so that tests can override it.

        """
        now = datetime.datetime.now()
        return now.isoformat(self.separator)
|
707
|
|
|
|