1 | import inspect |
||
0 ignored issues
–
show
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).
Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.
Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.wrappers -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).
Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.
Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.bn -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).
Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.
Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.misc -> blocks.bricks.parallel -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).
Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.
Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.architectures -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).
Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.
Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.misc -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).
Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.
Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.architectures -> blocks.bricks.recurrent.base -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).
Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.
Loading history...
|
|||
2 | import logging |
||
3 | from blocks.extensions import SimpleExtension |
||
4 | |||
5 | |||
6 | logger = logging.getLogger(__name__) |
||
7 | |||
8 | |||
class SharedVariableModifier(SimpleExtension):
    """Adjusts shared variable parameter using some function.

    Applies a function to compute the new value of a shared parameter each
    iteration.

    This class can be used to adapt parameters such as the learning rate,
    momentum, etc. over the course of training.

    Parameters
    ----------
    parameter : :class:`~tensor.TensorSharedVariable`
        Shared variable to be adjusted.
    function : callable
        A function which outputs a numeric value to which the
        given shared variable will be set. It may take one or two
        arguments.

        With one argument, it receives the total number of iterations
        done (``int``).

        With two arguments, it receives the number of iterations done
        (``int``) and the old value of the shared variable (with the same
        dtype as `parameter`).
    num_args : int, optional
        The number of arguments to pass to the function. If unspecified,
        it will be inferred from the function's signature (the implicit
        ``self`` of a bound method is not counted). This is useful if you
        are using function-like objects for which the arity of the
        function cannot be inferred.

    Notes
    -----
    This class includes a method ``function`` that calls the function
    passed in the constructor and a ``num_args`` property which computes
    the number of arguments to use by inspecting the function object.
    Subclasses may override a method called ``function`` and/or
    the ``num_args`` property and instead pass ``None`` to the superclass
    constructor. This can be used to bypass certain serialization issues
    on Legacy Python regarding the unpicklability of instance
    method objects.

    """
    def __init__(self, parameter, function, num_args=None, **kwargs):
        kwargs.setdefault("after_batch", True)
        super(SharedVariableModifier, self).__init__(**kwargs)
        self.parameter = parameter
        self._function = function
        self._num_args = num_args

    @property
    def num_args(self):
        """Number of arguments passed to :meth:`function` on each call.

        Inferred lazily from the constructor's ``function`` on first access
        (unless given explicitly) and cached afterwards.

        """
        if self._num_args is None:
            self._num_args = self._count_positional_args(self._function)
        return self._num_args

    @staticmethod
    def _count_positional_args(func):
        """Return how many named positional parameters ``func`` accepts."""
        signature = getattr(inspect, 'signature', None)
        if signature is not None:
            # ``inspect.signature`` excludes the bound ``self`` of methods,
            # unlike ``getargspec`` (which was removed in Python 3.11).
            positional = (inspect.Parameter.POSITIONAL_ONLY,
                          inspect.Parameter.POSITIONAL_OR_KEYWORD)
            return sum(1 for param in signature(func).parameters.values()
                       if param.kind in positional)
        # Legacy Python: ``inspect.signature`` does not exist.
        count = len(inspect.getargspec(func).args)
        if inspect.ismethod(func) and func.__self__ is not None:
            # Bound methods receive ``self`` implicitly; don't count it.
            count -= 1
        return count

    def function(self, *args):
        """Call the function given to the constructor with ``args``."""
        return self._function(*args)

    def do(self, which_callback, *args):
        """Compute the new value and assign it to the shared parameter."""
        iterations_done = self.main_loop.log.status['iterations_done']
        if self.num_args == 1:
            new_value = self.function(iterations_done)
        else:
            # Any arity other than one gets the old value as well.
            old_value = self.parameter.get_value()
            new_value = self.function(iterations_done, old_value)
        self.parameter.set_value(new_value)
||
74 | |||
75 | |||
class TrackTheBest(SimpleExtension):
    """Record whether a logged quantity reached its best value so far.

    Each time it runs, the extension reads ``record_name`` from the current
    log row and compares it with the best value kept in the main loop
    status. If the new value wins, it is stored as the new best and a
    ``True`` notification is written to the current log row.

    Parameters
    ----------
    record_name : str
        Name of the log record to track.
    notification_name : str, optional
        Name of the log record set to ``True`` whenever the tracked value
        is the best so far. If not given (or empty), ``record_name`` with
        the suffix ``"_best_so_far"`` is used.
    choose_best : callable, optional
        Called as ``choose_best(current, best)``; must return whichever of
        the two is better. Defaults to :func:`min`, i.e. tracking the
        minimum value.

    Attributes
    ----------
    best_name : str
        Status key holding the best value so far (``"best_" + record_name``).
    notification_name : str
        Name of the log record written when the current value of the
        tracked quantity is the best so far.

    Notes
    -----
    If another extension adds the tracked quantity to the log, place this
    extension *after* that extension in the `extensions` argument to
    :class:`blocks.main_loop.MainLoop`.

    Unless overridden, the extension is triggered after every epoch
    (``after_epoch=True``). A value equal to the stored best is not
    treated as a new best.

    """
    def __init__(self, record_name, notification_name=None,
                 choose_best=min, **kwargs):
        self.record_name = record_name
        # An empty or missing name falls back to the derived default.
        self.notification_name = (notification_name or
                                  record_name + "_best_so_far")
        self.best_name = "best_" + record_name
        self.choose_best = choose_best
        kwargs.setdefault("after_epoch", True)
        super(TrackTheBest, self).__init__(**kwargs)

    def do(self, which_callback, *args):
        """Update the stored best and notify the log on improvement."""
        owner = type(self).__name__
        loop = self.main_loop

        current = loop.log.current_row.get(self.record_name)
        logger.debug('%s: current value of log.current_row["%s"] = %s',
                     owner, self.record_name, str(current))
        if current is None:
            # Nothing was logged for the tracked record this time.
            return

        best = loop.status.get(self.best_name, None)
        logger.debug('%s: current value of status["%s"] = %s',
                     owner, self.best_name, str(best))

        # First observation always wins; afterwards a strictly different
        # value must be the one picked by ``choose_best``.
        improved = best is None or (
            current != best and self.choose_best(current, best) == current)
        if not improved:
            return

        logger.debug('%s: New best obtained at iteration %d!',
                     owner, loop.log.status['iterations_done'])
        logger.debug('%s: Updating status["%s"], adding notification '
                     'to log (%s)', owner, self.best_name,
                     self.notification_name)
        loop.status[self.best_name] = current
        loop.log.current_row[self.notification_name] = True
||
140 |
Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.