1 | """Extensions for monitoring the training process.""" |
||
2 | import logging |
||
3 | |||
4 | import theano |
||
5 | |||
6 | from blocks.extensions import SimpleExtension, TrainingExtension |
||
7 | from blocks.algorithms import UpdatesAlgorithm |
||
8 | from blocks.monitoring.aggregation import MonitoredQuantity, take_last |
||
9 | from blocks.monitoring.evaluators import ( |
||
10 | AggregationBuffer, MonitoredQuantityBuffer, DatasetEvaluator) |
||
11 | |||
12 | SEPARATOR = '_' |
||
13 | logger = logging.getLogger(__name__) |
||
14 | |||
15 | |||
16 | class MonitoringExtension(TrainingExtension): |
||
17 | """A mixin with logic shared by monitoring extensions. |
||
18 | |||
19 | Parameters |
||
20 | ---------- |
||
21 | prefix : str, optional |
||
22 | The prefix for the log records done by the extension. It is |
||
23 | prepended to the variable names with an underscore as a separator. |
||
24 | If not given, no prefix is added to the names of the observed |
||
25 | variables. |
||
26 | suffix : str, optional |
||
27 | The suffix for the log records done by the extension. It is |
||
28 | appended to the end of variable names with an underscore as a |
||
29 | separator. If not given, no suffix is added the names of the |
||
30 | observed variables. |
||
31 | |||
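
    Examples
    --------
    A small sketch of how record names are formed from the prefix, the
    variable name, and the suffix (illustrative only; this class is
    normally used as a mixin by the monitoring extensions below rather
    than on its own):

    >>> ext = MonitoringExtension(prefix="valid", suffix="mean")
    >>> ext._record_name("cost")
    'valid_cost_mean'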
32 | """ |
||
33 | SEPARATOR = SEPARATOR |
||
34 | |||
35 | def __init__(self, prefix=None, suffix=None, **kwargs): |
||
36 | super(MonitoringExtension, self).__init__(**kwargs) |
||
37 | self.prefix = prefix |
||
38 | self.suffix = suffix |
||
39 | |||
40 | def _record_name(self, name): |
||
41 | """The record name for a variable name.""" |
||
42 | if not isinstance(name, str): |
||
43 | raise ValueError("record name must be a string") |
||
44 | |||
45 | return self.SEPARATOR.join( |
||
46 | [morpheme for morpheme in [self.prefix, name, self.suffix] |
||
47 | if morpheme is not None]) |
||
48 | |||
49 | def record_name(self, variable): |
||
50 | """The record name for a variable.""" |
||
51 | return self._record_name(variable.name) |
||
52 | |||
53 | def add_records(self, log, record_tuples): |
||
54 | """Helper function to add monitoring records to the log.""" |
||
55 | for name, value in record_tuples: |
||
56 | if not name: |
||
57 | raise ValueError("monitor variable without name") |
||
58 | log.current_row[self._record_name(name)] = value |
||
59 | |||
60 | |||
61 | class DataStreamMonitoring(SimpleExtension, MonitoringExtension): |
||
62 | """Monitors Theano variables and monitored-quantities on a data stream. |
||
63 | |||
64 | By default monitoring is done before the first and after every epoch. |
||
65 | |||
66 | Parameters |
||
67 | ---------- |
||
68 | variables : list of :class:`~tensor.TensorVariable` and |
||
69 | :class:`MonitoredQuantity` |
||
70 | The variables to monitor. The variable names are used as record |
||
71 | names in the logs. |
||
72 | updates : list of tuples or :class:`~collections.OrderedDict` or None |
||
73 | :class:`~tensor.TensorSharedVariable` updates to be performed |
||
74 | during evaluation. This parameter is only for Theano variables. |
||
75 | Be careful not to update any model parameters as this is not |
||
76 | intended to alter your model in any meaningful way. A typical |
||
77 | use case of this option arises when the theano function used |
||
78 | for evaluation contains a call to :func:`~theano.scan` which |
||
79 | might have returned shared variable updates. |
||
80 | data_stream : instance of :class:`.DataStream` |
||
81 | The data stream to monitor on. A data epoch is requested |
||
82 | each time monitoring is done. |
||
83 | |||
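
    Examples
    --------
    A minimal sketch of typical usage; ``cost`` is assumed to be a named
    Theano variable and ``valid_stream`` a data stream built elsewhere:

    >>> monitor = DataStreamMonitoring(
    ...     variables=[cost], data_stream=valid_stream,
    ...     prefix="valid")  # doctest: +SKIP

    The evaluated values then appear in the log under prefixed names
    such as ``valid_cost``.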
84 | """ |
||
85 | def __init__(self, variables, data_stream, updates=None, **kwargs): |
||
86 | kwargs.setdefault("after_epoch", True) |
||
87 | kwargs.setdefault("before_first_epoch", True) |
||
88 | super(DataStreamMonitoring, self).__init__(**kwargs) |
||
89 | self._evaluator = DatasetEvaluator(variables, updates) |
||
90 | self.data_stream = data_stream |
||
91 | |||
92 | def do(self, callback_name, *args): |
||
93 | """Write the values of monitored variables to the log.""" |
||
94 | logger.info("Monitoring on auxiliary data started") |
||
95 | value_dict = self._evaluator.evaluate(self.data_stream) |
||
96 | self.add_records(self.main_loop.log, value_dict.items()) |
||
97 | logger.info("Monitoring on auxiliary data finished") |
||
98 | |||
99 | |||
100 | class TrainingDataMonitoring(SimpleExtension, MonitoringExtension): |
||
101 | """Monitors values of Theano variables on training batches. |
||
102 | |||
103 | Use this extension to monitor a quantity on every training batch |
||
104 | cheaply. It integrates with the training algorithm in order to avoid |
||
105 | recomputing same things several times. For instance, if you are |
||
106 | training a network and you want to log the norm of the gradient on |
||
107 | every batch, the backpropagation will only be done once. By |
||
108 | controlling the frequency with which the :meth:`do` method is called, |
||
109 | you can aggregate the monitored variables, e.g. only log the gradient |
||
110 | norm average over an epoch. |
||
111 | |||
112 | Parameters |
||
113 | ---------- |
||
114 | variables : list of :class:`~tensor.TensorVariable` or |
||
115 | :class:`~blocks.monitoring.aggregation.MonitoredQuantity` |
||
116 | The variables or non-Theano quantities to monitor. |
||
117 | The variable names are used as record names in the logs. |
||
118 | |||
119 | Notes |
||
120 | ----- |
||
121 | All the monitored variables are evaluated _before_ the parameter |
||
122 | update. |
||
123 | |||
124 | Requires the training algorithm to be an instance of |
||
125 | :class:`.UpdatesAlgorithm`. |
||
126 | |||
127 | """ |
||
128 | def __init__(self, variables, **kwargs): |
||
129 | kwargs.setdefault("before_training", True) |
||
130 | super(TrainingDataMonitoring, self).__init__(**kwargs) |
||
131 | self.add_condition(['after_batch'], arguments=('just_aggregate',)) |
||
132 | |||
133 | self._non_variables = [] |
||
134 | self._variables = [] |
||
135 | for variable_or_not in variables: |
||
136 | if isinstance(variable_or_not, theano.Variable): |
||
137 | self._variables.append(variable_or_not) |
||
138 | elif isinstance(variable_or_not, MonitoredQuantity): |
||
139 | self._non_variables.append(variable_or_not) |
||
140 | else: |
||
141 | raise ValueError("can not monitor {}".format(variable_or_not)) |
||
142 | |||
143 | self._non_variables = MonitoredQuantityBuffer(self._non_variables) |
||
144 | self._required_for_non_variables = AggregationBuffer( |
||
145 | [take_last(v) for v in self._non_variables.requires]) |
||
146 | self._variables = AggregationBuffer( |
||
147 | self._variables, use_take_last=True) |
||
148 | self._last_time_called = -1 |
||
149 | |||
150 | def do(self, callback_name, *args): |
||
151 | """Initializes the buffer or commits the values to the log. |
||
152 | |||
153 | What this method does depends on from what callback it is called |
||
154 | and with which arguments. When called within `before_training`, it |
||
155 | initializes the aggregation buffer and instructs the training |
||
156 | algorithm what additional computations should be carried at each |
||
157 | step by adding corresponding updates to it. In most_other cases it |
||
158 | writes aggregated values of the monitored variables to the log. An |
||
159 | exception is when an argument `just_aggregate` is given: in this |
||
160 | cases it updates the values of monitored non-Theano quantities, but |
||
161 | does not write anything to the log. |
||
162 | |||
163 | """ |
||
164 | data, args = self.parse_args(callback_name, args) |
||
        if callback_name == 'before_training':
            if not isinstance(self.main_loop.algorithm,
                              UpdatesAlgorithm):
                raise ValueError("TrainingDataMonitoring requires the "
                                 "training algorithm to be an instance "
                                 "of UpdatesAlgorithm")
            self.main_loop.algorithm.add_updates(
                self._variables.accumulation_updates)
            self.main_loop.algorithm.add_updates(
                self._required_for_non_variables.accumulation_updates)
            self._variables.initialize_aggregators()
            self._required_for_non_variables.initialize_aggregators()
            self._non_variables.initialize_quantities()
        else:
            # When called for the first time within an iteration, update
            # the monitored non-Theano quantities
            if (self.main_loop.status['iterations_done'] >
                    self._last_time_called):
                self._non_variables.aggregate_quantities(
                    list(self._required_for_non_variables
                         .get_aggregated_values().values()))
                self._required_for_non_variables.initialize_aggregators()
                self._last_time_called = (
                    self.main_loop.status['iterations_done'])
            # If only called to update the non-Theano quantities,
            # do just that
            if args == ('just_aggregate',):
                return
            # Otherwise, also write the current values from the
            # accumulators to the log.
            self.add_records(
                self.main_loop.log,
                self._variables.get_aggregated_values().items())
            self._variables.initialize_aggregators()
            self.add_records(
                self.main_loop.log,
                self._non_variables.get_aggregated_values().items())
            self._non_variables.initialize_quantities()