1 | """Blocks native serialization - tar files with pickles and numpy arrays. |
||
2 | |||
3 | This module provides :func:`load` and :func:`dump` functions that can serve |
||
4 | as drop-in replacement for the respective functions from the standard |
||
5 | :mod:`pickle` module. The main differences between them and the standard |
||
6 | ones are: |
||
7 | |||
8 | - The dump is physically a tarball, in which the pickle is stored |
||
9 | as '_pkl' file. |
||
10 | |||
11 | - A special file '_parameters' in the tarball can contain the data |
||
12 | of a selected set of Theano shared variables. This data is |
||
13 | referenced from `_pkl` using persistent id mechanism, which means |
||
14 | that no duplication takes place. The goal here is to save the values |
||
15 | of the parameters (this is what these shared variables are in most |
||
16 | cases) in the most robust way possible. The actual format for |
||
17 | '_parameters' file is the one used by :func:`numpy.savez`, i.e. a zip |
||
18 | file of numpy arrays. |
||
19 | |||
20 | - More objects can be dumped in the archive using the `add_to_dump` |
||
21 | function. If the object has the same parameters as the one already |
||
22 | dumped, then you can avoid dumping those parameters again thanks to the |
||
23 | persistent id mechanism. |
||
24 | |||
25 | - The :func:`dump` strives to catch situations when the user tries |
||
26 | to pickle a function or a class not defined in the global namespace |
||
27 | and give a meaningful warning. |
||
28 | |||
29 | In brief, this module provides a dumping mechanism that allows for |
||
30 | greater robustness and persistence than standard pickling. |
||
31 | |||
32 | Examples |
||
33 | -------- |
||
34 | Consider a standard main loop (without an algorithm and a data stream |
||
35 | for brevity) |
||
36 | |||
37 | >>> from theano import tensor |
||
38 | >>> from blocks.main_loop import MainLoop |
||
39 | >>> from blocks.bricks import MLP, Tanh, Softmax |
||
40 | >>> from blocks.model import Model |
||
41 | >>> mlp = MLP([Tanh(), None], [784, 10, 10]) |
||
42 | >>> x = tensor.matrix('features') |
||
43 | >>> y = tensor.lmatrix('targets') |
||
44 | >>> cost = Softmax().categorical_cross_entropy( |
||
45 | ... y.flatten(), mlp.apply(tensor.flatten(x, outdim=2))) |
||
46 | >>> main_loop = MainLoop(None, None, model=Model(cost)) |
||
47 | |||
48 | Let's see how the main loop is dumped by :func:`dump` |
||
49 | |||
50 | >>> from blocks.serialization import dump, load |
||
51 | >>> import tarfile |
||
52 | >>> with open('main_loop.tar', 'wb') as dst: |
||
53 | ... dump(main_loop, dst) |
||
54 | >>> tarball = tarfile.open('main_loop.tar', 'r') |
||
55 | >>> tarball # doctest: +ELLIPSIS |
||
56 | <tarfile.TarFile object at ...> |
||
57 | >>> tarball.getnames() |
||
58 | ['_pkl'] |
||
59 | >>> tarball.close() |
||
60 | |||
61 | As promised, the dump is a tarball. Since we did not ask for any additional |
||
62 | magic, it just contains the pickled main loop in '_pkl' file. |
||
63 | |||
64 | Let's do something more interesting: |
||
65 | |||
66 | >>> with open('main_loop.tar', 'wb') as dst: |
||
67 | ... dump(main_loop, dst, |
||
68 | ... parameters=main_loop.model.parameters) |
||
69 | >>> tarball = tarfile.open('main_loop.tar', 'r') |
||
70 | >>> tarball.getnames() |
||
71 | ['_parameters', '_pkl'] |
||
72 | |||
73 | As requested by specifying the `parameters` argument, the parameters were |
||
74 | saved in a zip file. |
||
75 | |||
76 | >>> import numpy |
||
77 | >>> ps = numpy.load(tarball.extractfile(tarball.getmember('_parameters'))) |
||
78 | >>> sorted(ps.keys()) # doctest: +ELLIPSIS |
||
79 | ['|mlp|linear_0.W', '|mlp|linear_0.b', '|mlp|linear_1.W', '|mlp|lin...] |
||
80 | >>> ps.close() |
||
81 | |||
82 | The names for parameters are chosen intelligently to reflect their |
||
83 | position in the brick hierarchy, if they belong to bricks, and by |
||
84 | simply using the `.name` attribute, if they do not. |
||
85 | |||
86 | The loading of the main loop as a whole still works: |
||
87 | |||
88 | >>> with open('main_loop.tar', 'rb') as src: |
||
89 | ... main_loop_loaded = load(src) |
||
90 | >>> main_loop_loaded # doctest: +ELLIPSIS |
||
91 | <blocks.main_loop.MainLoop object at ...> |
||
92 | |||
93 | Additionally, this module provides convenience routine |
||
94 | :func:`load_parameters`: |
||
95 | |||
96 | >>> with open('main_loop.tar', 'rb') as src: |
||
97 | ... parameters = load_parameters(src) |
||
98 | >>> sorted(parameters.keys()) # doctest: +ELLIPSIS |
||
99 | ['/mlp/linear_0.W', '/mlp/linear_0.b', '/mlp/linear_1.W', '/mlp/line...] |
||
100 | |||
101 | Loading parameters saved by :func:`dump` with :func:`load_parameters` |
||
102 | ensures that their hierarchical names are compatible with |
||
103 | :class:`~blocks.model.Model` and :class:`~blocks.select.Selector` classes. |
||
104 | |||
105 | TODO: Add information about :func:`add_to_dump`. |
||
106 | |||
107 | """ |
||
108 | import numpy |
||
109 | import os |
||
110 | import pickle |
||
111 | import shutil |
||
112 | import six |
||
113 | import tarfile |
||
114 | import tempfile |
||
115 | import warnings |
||
116 | import logging |
||
117 | from contextlib import closing |
||
118 | from pickle import HIGHEST_PROTOCOL |
||
119 | try: |
||
120 | from pickle import DEFAULT_PROTOCOL |
||
121 | from pickle import _Pickler |
||
122 | except ImportError: |
||
123 | DEFAULT_PROTOCOL = HIGHEST_PROTOCOL |
||
124 | from pickle import Pickler as _Pickler |
||
125 | |||
126 | from six.moves import cPickle |
||
127 | import theano |
||
128 | try: |
||
129 | from theano.sandbox.cuda import cuda_ndarray |
||
130 | except Exception: |
||
131 | cuda_ndarray = None |
||
0 ignored issues
–
show
|
|||
132 | try: |
||
133 | import pygpu |
||
134 | except Exception: |
||
135 | pygpu = None |
||
136 | from blocks.config import config |
||
137 | from blocks.filter import get_brick |
||
138 | from blocks.utils import change_recursion_limit |
||
139 | from blocks.bricks.base import BRICK_DELIMITER |
||
140 | |||
141 | |||
logger = logging.getLogger(__name__)

# Delimiter used in parameter names stored inside the archive.
# :func:`load_parameters` translates it back to the in-memory
# BRICK_DELIMITER used by Model/Selector.
SERIALIZATION_BRICK_DELIMITER = '|'
# Warning emitted when pickling an object defined in `__main__`;
# formatted with the offending function's name.
MAIN_MODULE_WARNING = """WARNING: Main loop depends on the function `{}` in \
`__main__` namespace.

Because of limitations to pickling, this means that you will not be able to \
resume your model outside of a namespace containing this function. In other \
words, you can only call `continue_training` from within this script."""
||
151 | |||
152 | |||
def dump(object_, file_, parameters=None, use_cpickle=False,
         protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object, optionally saving its parameters separately.

    The result is a tar archive containing a '_pkl' member (the pickled
    `object_`) and, when `parameters` is given, a '_parameters' member
    (a :func:`numpy.savez`-style zip of the parameter values).

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    with closing(tarfile.TarFile(fileobj=file_, mode='w')) as tar_file:
        external_objects = {}

        def _save_parameters(f):
            # Give every parameter a unique hierarchy-based name and save
            # all values into a single .npz stream.
            renamer = _Renamer()
            named_parameters = {renamer(p): p for p in parameters}
            numpy.savez(f, **{n: p.get_value()
                              for n, p in named_parameters.items()})
            # Record the raw storage objects so the pickler replaces them
            # with persistent ids instead of serializing them twice.
            for name, p in named_parameters.items():
                array_ = p.container.storage[0]
                external_objects[id(array_)] = _mangle_parameter_name(p, name)
        if parameters:
            _taradd(_save_parameters, tar_file, '_parameters')
        if object_ is not None:
            save_object = _SaveObject(pickler, object_, external_objects,
                                      protocol, **kwargs)
            _taradd(save_object, tar_file, '_pkl')
||
205 | |||
206 | |||
def secure_dump(object_, path, dump_function=dump, **kwargs):
    r"""Robust serialization - does not corrupt your files when failed.

    The object is first serialized into a temporary file and only then
    moved over `path`, so a failure during serialization never leaves a
    half-written file at the destination.

    Parameters
    ----------
    object_ : object
        The object to be saved to the disk.
    path : str
        The destination for saving.
    dump_function : function
        The function that is used to perform the serialization. Must take
        an object and file object as arguments. By default, :func:`dump` is
        used. An alternative would be :func:`pickle.dump`.
    \*\*kwargs
        Keyword arguments to be passed to `dump_function`.

    """
    try:
        logger.debug("Dumping object to a temporary file")
        with tempfile.NamedTemporaryFile(delete=False,
                                         dir=config.temp_dir) as temp:
            dump_function(object_, temp, **kwargs)
        logger.debug("Moving the temporary file")
        shutil.move(temp.name, path)
        logger.debug("Dump finished")
    except:
        # Bare except is deliberate: clean up the temporary file on ANY
        # failure (including KeyboardInterrupt) and always re-raise.
        if "temp" in locals():
            os.remove(temp.name)
        raise
||
236 | |||
237 | |||
def load(file_, name='_pkl', use_cpickle=False, **kwargs):
    r"""Loads an object saved using the `dump` function.

    By default, this function loads the object saved by the `dump`
    function. If some objects have been added to the archive using the
    `add_to_dump` function, then you can load them by passing their name
    to the `name` parameter.

    Parameters
    ----------
    file_ : file
        The file that contains the object to load.
    name : str
        Name of the object to load. Default is `_pkl`, meaning that it is
        the original object which have been dumped that is loaded.
    use_cpickle : bool
        Use cPickle instead of pickle. Default: False.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Unpickler`.
        Used for e.g. specifying the encoding so as to load legacy Python
        pickles under Python 3.x.

    Returns
    -------
    The object saved in ``file_``.

    """
    # Rewind so that several objects can be read from the same file.
    file_.seek(0)
    unpickler_class = cPickle.Unpickler if use_cpickle else pickle.Unpickler
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
        pickle_stream = tar_file.extractfile(tar_file.getmember(name))
        unpickler = unpickler_class(pickle_stream, **kwargs)
        if '_parameters' in tar_file.getnames():
            # Parameter arrays were stored separately; resolve their
            # persistent ids from the '_parameters' member.
            unpickler.persistent_load = _PersistentLoad(tar_file)
        return unpickler.load()
||
278 | |||
279 | |||
def load_parameters(file_):
    """Loads the parameter values saved by :func:`dump`.

    This function loads the parameters that have been saved separately by
    :func:`dump`, i.e. the ones given to its `parameters` argument.

    Parameters
    ----------
    file_ : file
        The source to load the parameters from.

    Returns
    -------
    A dictionary of (parameter name, numpy array) pairs.

    """
    npz_file = _load_parameters_npzfile(file_)
    try:
        parameters = {}
        for name, value in npz_file.items():
            # Translate archive-style names back to Model/Selector-style
            # hierarchical names.
            canonical_name = name.replace(SERIALIZATION_BRICK_DELIMITER,
                                          BRICK_DELIMITER)
            parameters[canonical_name] = value
        return parameters
    finally:
        npz_file.close()
||
300 | |||
301 | |||
def add_to_dump(object_, file_, name, parameters=None, use_cpickle=False,
                protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object to an existing tar archive.

    This function allows to dump more objects to an existing archive. If
    the object you want to dump possesses the same set of shared variables
    as the object already dumped, you can pass them to the `parameters`
    argument, which will avoid them to be serialized a second time.
    However, it won't work if the shared variable you pass to the
    `parameters` argument are not already in the archive.

    Parameters
    ----------
    object_ : object
        The object to pickle.
    file_ : file
        The destination for saving, opened in read-write mode (`r+`).
    name : str
        The name of the object you are dumping. It will be used as a file
        name in the archive. '_pkl' and '_parameters' are reserved names
        and can't be used.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file. Must be a
        subset of the parameters already in the archive.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module!
        Be sure that you don't have the warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    if name in ['_pkl', '_parameters']:
        raise ValueError("_pkl and _parameters are reserved names and can't"
                         " be used as name for your object.")

    external_parameters = {}
    if parameters is not None:
        # Mangle the names of the given parameters exactly as `dump` did,
        # so their persistent ids match the ones already in the archive.
        renamer = _Renamer()
        named_parameters = {renamer(p): p for p in parameters}
        for n, p in named_parameters.items():
            array_ = p.container.storage[0]
            external_parameters[id(array_)] = _mangle_parameter_name(p, n)

        # Check that the parameters are a subset of the ones in the archive.
        file_.seek(0)  # To be able to read what is in the tar file already.
        with closing(tarfile.TarFile(fileobj=file_, mode='r')) as tar_file:
            if '_parameters' not in tar_file.getnames():
                raise ValueError("There is no parameters in the archive, so"
                                 " you can't use the argument parameters.")
            else:
                parameters = numpy.load(
                    tar_file.extractfile(tar_file.getmember('_parameters')))
                s1 = set(parameters.keys())
                s2 = [_unmangle_parameter_name(x)[2] for x in
                      external_parameters.values()]
                if not s1.issuperset(s2):
                    raise ValueError('The set of parameters is different'
                                     ' from the one in the archive.')

    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    file_.seek(0)  # To be able to add new things in the tar file.
    with closing(tarfile.TarFile(fileobj=file_, mode='a')) as tar_file:
        save_object = _SaveObject(pickler, object_, external_parameters,
                                  protocol, **kwargs)
        _taradd(save_object, tar_file, name)
||
377 | |||
378 | |||
def continue_training(path):
    """Continues training using checkpoint.

    Parameters
    ----------
    path : str
        Path to checkpoint.

    Notes
    -----
    Python picklers can unpickle objects from the global namespace only
    if they are present in the namespace where unpickling happens. Global
    functions are often needed for mapping, filtering and other data
    stream operations, so if the main loop uses such objects and this
    function fails with a message like
    ``AttributeError: 'module' object has no attribute '...'``
    it means that you need to import these objects.

    Examples
    --------
    This function can be used in two ways: in your script where a main
    loop is defined, or in a different script. For the latter option see
    the Notes section.

    """
    # Unpickling a main loop can recurse deeply, so raise the recursion
    # limit for the duration of the load.
    with change_recursion_limit(config.recursion_limit), \
            open(path, "rb") as checkpoint:
        main_loop = load(checkpoint)
    main_loop.run()
||
410 | |||
411 | |||
def dump_and_add_to_dump(object_, file_, parameters=None, to_add=None,
                         use_cpickle=False, protocol=DEFAULT_PROTOCOL,
                         **kwargs):
    r"""Calls both `dump` and `add_to_dump` to serialize several objects.

    This function is used to serialize several objects at the same time,
    using persistent ID. Its main advantage is that it can be used with
    `secure_dump`.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    to_add : dict of objects
        A {'name': object} dictionary of additional objects to save in
        the tar archive. Its keys will be used as names in the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    dump(object_, file_, parameters=parameters, use_cpickle=use_cpickle,
         protocol=protocol, **kwargs)
    if to_add is None:
        return
    # Append every extra object to the archive just written; shared
    # parameters are deduplicated through their persistent ids.
    for extra_name, extra_object in six.iteritems(to_add):
        add_to_dump(extra_object, file_, extra_name, parameters=parameters,
                    use_cpickle=use_cpickle, protocol=protocol, **kwargs)
453 | |||
454 | |||
class _PicklerWithWarning(_Pickler):
    """Pickler that adds a warning message.

    Adds a warning message if we try to save an object referenced in the
    main module.

    """
    # Copy the dispatch table so that registering our handler does not
    # mutate the base Pickler's shared table.
    dispatch = _Pickler.dispatch.copy()

    def save_global(self, obj, name=None, **kwargs):
        module = getattr(obj, '__module__', None)
        if module == '__main__':
            # NOTE(review): since `name` is a named parameter, 'name' can
            # never be in `kwargs`, so this always warns with
            # `obj.__name__` — confirm whether `name` was intended here.
            warnings.warn(
                MAIN_MODULE_WARNING.format(kwargs.get('name', obj.__name__))
            )
        _Pickler.save_global(self, obj, name=name, **kwargs)

    # Route functions (and on Python 2 also old-style classes, builtins
    # and types) through the warning wrapper above.
    dispatch[six.types.FunctionType] = save_global
    if six.PY2:
        dispatch[six.types.ClassType] = save_global
        dispatch[six.types.BuiltinFunctionType] = save_global
        dispatch[six.types.TypeType] = save_global
477 | |||
478 | |||
class _SaveObject(object):
    r"""Callable that pickles an object with persistent ids when invoked.

    Parameters
    ----------
    pickler : object
        The pickler class to use.
    object_ : object
        The object to pickle.
    external_objects : dict of object
        The external objects to save using persistent id.
    protocol : int, optional
        The pickling protocol to use.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    def __init__(self, pickler, object_, external_objects, protocol, **kwargs):
        self.pickler = pickler
        self.object_ = object_
        self.external_objects = external_objects
        self.protocol = protocol
        self.kwargs = kwargs

    def __call__(self, f):
        pickler_instance = self.pickler(
            f, protocol=self.protocol, **self.kwargs)
        # Arrays registered in external_objects are replaced by their
        # persistent ids instead of being serialized inline.
        pickler_instance.persistent_id = _PersistentID(self.external_objects)
        pickler_instance.dump(self.object_)
||
507 | |||
508 | |||
class _Renamer(object):
    """Returns a new name for the given parameter.

    It maintains a list of names already used to avoid naming
    collisions. It also provides names for variables without
    names.

    Attributes
    ----------
    used_names : set
        The set of names already taken.
    default_name : str
        The name to use if a parameter doesn't have a name. Default:
        'parameter'.

    """
    def __init__(self):
        self.used_names = set()
        self.default_name = 'parameter'

    def __call__(self, parameter):
        brick = get_brick(parameter)
        if brick is not None:
            # Standard Blocks parameter: derive the name from its place
            # in the brick hierarchy.
            name = brick.get_hierarchical_name(
                parameter, SERIALIZATION_BRICK_DELIMITER)
        elif hasattr(parameter.tag, 'name'):
            # Shared variable annotated through its tag.
            name = parameter.tag.name
        elif parameter.name is not None:
            # Plain shared variable carrying its own name.
            name = parameter.name
        else:
            # Completely anonymous variable.
            name = self.default_name
        # Resolve collisions by appending '_2', '_3', ... until unique.
        candidate, suffix = name, 2
        while candidate in self.used_names:
            candidate = '{}_{}'.format(name, suffix)
            suffix += 1
        self.used_names.add(candidate)
        return candidate
||
553 | |||
554 | |||
def _recreate_numpy_ndarray(_, content):
    """Rebuilds a numpy array from saved content (context name unused)."""
    return numpy.array(content)
||
557 | |||
558 | |||
def _recreate_cuda_ndarray(_, content):
    """Rebuilds an old-backend CudaNdarray (context name unused)."""
    return cuda_ndarray.cuda_ndarray.CudaNdarray(content)
||
561 | |||
562 | |||
def _recreate_pygpu_array(context_name, content):
    """Rebuilds a pygpu array on the Theano context named `context_name`."""
    context = theano.gpuarray.get_context(context_name)
    return pygpu.gpuarray.array(content, context=context)
||
566 | |||
# Forward map: array type -> tag embedded in mangled parameter names.
# Inverse map: tag -> function that rebuilds an array of that type.
_ARRAY_TYPE_MAP = {numpy.ndarray: 'numpy_ndarray'}
_INVERSE_ARRAY_TYPE_MAP = {'numpy_ndarray': _recreate_numpy_ndarray}
if cuda_ndarray:
    # Legacy theano.sandbox.cuda backend, when importable.
    _ARRAY_TYPE_MAP[cuda_ndarray.cuda_ndarray.CudaNdarray] = 'cuda_ndarray'
    _INVERSE_ARRAY_TYPE_MAP['cuda_ndarray'] = _recreate_cuda_ndarray
if pygpu:
    # New gpuarray backend, when importable.
    _ARRAY_TYPE_MAP[pygpu.gpuarray.GpuArray] = 'gpuarray'
    _INVERSE_ARRAY_TYPE_MAP['gpuarray'] = _recreate_pygpu_array
||
575 | |||
576 | |||
577 | class _PersistentID(object): |
||
578 | """Returns persistent identifiers for objects saved separately.""" |
||
579 | def __init__(self, external_objects): |
||
580 | self.external_objects = external_objects |
||
581 | |||
582 | def __call__(self, object_): |
||
583 | return self.external_objects.get(id(object_)) |
||
584 | |||
585 | |||
class _PersistentLoad(object):
    """Loads object saved using a PersistentID mechanism."""
    def __init__(self, tar_file):
        self.tar_file = tar_file
        if '_parameters' in tar_file.getnames():
            # Lazily-read zip of parameter arrays written by numpy.savez.
            self.parameters = numpy.load(
                tar_file.extractfile(tar_file.getmember('_parameters')))
        # Cache of reconstructed arrays, keyed by persistent id.
        self._cache = {}

    def __call__(self, id_):
        # As we empirically found out, this method can be called multiple
        # times with the same id_. That's why we need a cache here to
        # avoid creating the same object more than once.
        if id_ not in self._cache:
            # components = (recreation function, context name, array name)
            components = _unmangle_parameter_name(id_)
            self._cache[id_] = components[0](
                components[1], self.parameters[components[2]])
        return self._cache[id_]
||
604 | |||
605 | |||
def _mangle_parameter_name(parameter, name):
    """Encodes array type, GPU context and name into a persistent id."""
    array_type = type(parameter.container.storage[0])
    # NOTE(review): this isinstance check is applied to the shared
    # variable itself, not to its underlying storage array — confirm it
    # matches how the gpuarray backend exposes `context_name`.
    context_name = (parameter.context_name
                    if pygpu and
                    isinstance(parameter, pygpu.gpuarray.GpuArray)
                    else None)
    if isinstance(context_name, str) and '.' in context_name:
        # Dots separate the fields of the mangled name.
        raise ValueError("context name must not contain dots")
    # '#1' is the version tag of this mangling scheme; see
    # _unmangle_parameter_name for the legacy '#' format.
    return '#1{}.{}.{}'.format(
        _ARRAY_TYPE_MAP[array_type], context_name, name)
||
616 | |||
617 | |||
def _unmangle_parameter_name(mangled_name):
    """Decodes a persistent id into (recreate function, context, name)."""
    if not isinstance(mangled_name, str):
        # This fixes an issue with protocol 0 on Python 3 where
        # 'mangled_name' is a bytes object, for some reason.
        mangled_name = mangled_name.decode('utf8')
    if mangled_name.startswith('#1'):
        # Current scheme: '#1<array type>.<context name>.<name>'.
        array_type, context_name, name = mangled_name[2:].split('.', 2)
        return (_INVERSE_ARRAY_TYPE_MAP[array_type],
                None if context_name == 'None' else context_name,
                name)
    if mangled_name.startswith('#'):
        # Backward compatibility: '#<array type>.<name>', no context.
        array_type, name = mangled_name[1:].split('.', 1)
        return _INVERSE_ARRAY_TYPE_MAP[array_type], None, name
    raise ValueError("Do not recognize the mangled parameter name")
||
634 | |||
635 | |||
636 | def _taradd(func, tar_file, name): |
||
637 | """Adds elements dumped by the function `func` to a tar_file. |
||
638 | |||
639 | This functions first calls the function `func` and add the file that |
||
640 | `func` dumps to the achive `tar_file`, under the name `name`. |
||
641 | |||
642 | Parameters |
||
643 | ---------- |
||
644 | func : function |
||
645 | The dumping function. |
||
646 | tar_file : file |
||
647 | The archive that we are filling. |
||
648 | name : str |
||
649 | The name of the dumped file in the archive. |
||
650 | |||
651 | """ |
||
652 | with tempfile.NamedTemporaryFile('wb', delete=False) as temp_file: |
||
653 | func(temp_file) |
||
654 | temp_file.close() |
||
655 | tar_file.add(temp_file.name, arcname=name) |
||
656 | if os.path.isfile(temp_file.name): |
||
657 | os.remove(temp_file.name) |
||
658 | |||
659 | |||
def _load_parameters_npzfile(file_):
    """Loads parameters from a .npz file in a tar archive."""
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
        # NOTE(review): the returned NpzFile reads lazily from a member
        # handle of `tar_file`, which is closed when this `with` exits —
        # callers apparently rely on `file_` staying open; verify reads
        # after return actually succeed.
        return numpy.load(
            tar_file.extractfile(tar_file.getmember('_parameters')))
||
665 |
This check looks for invalid names for a range of different identifiers.
You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.
If your project includes a Pylint configuration file, the settings contained in that file take precedence.
To find out more about Pylint, please refer to their site.