1 | """Blocks native serialization - tar files with pickles and numpy arrays. |
||
2 | |||
3 | This module provides :func:`load` and :func:`dump` functions that can serve |
||
4 | as drop-in replacement for the respective functions from the standard |
||
5 | :mod:`pickle` module. The main differences between them and the standard |
||
6 | ones are: |
||
7 | |||
8 | - The dump is physically a tarball, in which the pickle is stored |
||
9 | as '_pkl' file. |
||
10 | |||
11 | - A special file '_parameters' in the tarball can contain the data |
||
12 | of a selected set of Theano shared variables. This data is |
||
13 | referenced from `_pkl` using persistent id mechanism, which means |
||
14 | that no duplication takes place. The goal here is to save the values |
||
15 | of the parameters (this is what these shared variables are in most |
||
16 | cases) in the most robust way possible. The actual format for |
||
17 | '_parameters' file is the one used by :func:`numpy.savez`, i.e. a zip |
||
18 | file of numpy arrays. |
||
19 | |||
20 | - More objects can be dumped in the archive using the `add_to_dump` |
||
21 | function. If the object has the same parameters as the one already |
||
22 | dumped, then you can avoid to dump those parameters thank to the |
||
23 | persistent id mechanism. |
||
24 | |||
25 | - The :func:`dump` strives to catch situations when the user tries |
||
26 | to pickle a function or a class not defined in the global namespace |
||
27 | and give a meaningful warning. |
||
28 | |||
29 | If briefly, this module proposes a dumping mechanism which allows for |
||
30 | greater robustness and persistence than standard pickling. |
||
31 | |||
32 | Examples |
||
33 | -------- |
||
34 | Consider a standard main loop (without an algorithm and a data stream |
||
35 | for brevity) |
||
36 | |||
37 | >>> from theano import tensor |
||
38 | >>> from blocks.main_loop import MainLoop |
||
39 | >>> from blocks.bricks import MLP, Tanh, Softmax |
||
40 | >>> from blocks.model import Model |
||
41 | >>> mlp = MLP([Tanh(), None], [784, 10, 10]) |
||
42 | >>> x = tensor.matrix('features') |
||
43 | >>> y = tensor.lmatrix('targets') |
||
44 | >>> cost = Softmax().categorical_cross_entropy( |
||
45 | ... y.flatten(), mlp.apply(tensor.flatten(x, outdim=2))) |
||
46 | >>> main_loop = MainLoop(None, None, model=Model(cost)) |
||
47 | |||
48 | Let's see how the main loop is dumped by :func:`dump` |
||
49 | |||
50 | >>> from blocks.serialization import dump, load |
||
51 | >>> import tarfile |
||
52 | >>> with open('main_loop.tar', 'wb') as dst: |
||
53 | ... dump(main_loop, dst) |
||
54 | >>> tarball = tarfile.open('main_loop.tar', 'r') |
||
55 | >>> tarball # doctest: +ELLIPSIS |
||
56 | <tarfile.TarFile object at ...> |
||
57 | >>> tarball.getnames() |
||
58 | ['_pkl'] |
||
59 | >>> tarball.close() |
||
60 | |||
61 | As promised, the dump is a tarball. Since we did not ask for any additional |
||
62 | magic, it just contains the pickled main loop in '_pkl' file. |
||
63 | |||
64 | Let's do something more interesting: |
||
65 | |||
66 | >>> with open('main_loop.tar', 'wb') as dst: |
||
67 | ... dump(main_loop, dst, |
||
68 | ... parameters=main_loop.model.parameters) |
||
69 | >>> tarball = tarfile.open('main_loop.tar', 'r') |
||
70 | >>> tarball.getnames() |
||
71 | ['_parameters', '_pkl'] |
||
72 | |||
73 | As requested by specifying the `_parameters` argument, the parameters were |
||
74 | saved in a zip file. |
||
75 | |||
76 | >>> import numpy |
||
77 | >>> ps = numpy.load(tarball.extractfile(tarball.getmember('_parameters'))) |
||
78 | >>> sorted(ps.keys()) # doctest: +ELLIPSIS |
||
79 | ['|mlp|linear_0.W', '|mlp|linear_0.b', '|mlp|linear_1.W', '|mlp|lin...] |
||
80 | >>> ps.close() |
||
81 | |||
82 | The names for parameters are chosen intelligently to reflect their |
||
83 | position in the brick hierarchy, if they belong to bricks, and by |
||
84 | simply using the `.name` attribute, if they do not. |
||
85 | |||
86 | The loading of the main loop as a whole still works: |
||
87 | |||
88 | >>> with open('main_loop.tar', 'rb') as src: |
||
89 | ... main_loop_loaded = load(src) |
||
90 | >>> main_loop_loaded # doctest: +ELLIPSIS |
||
91 | <blocks.main_loop.MainLoop object at ...> |
||
92 | |||
93 | Additionally, this module provides convenience routine |
||
94 | :func:`load_parameters`: |
||
95 | |||
96 | >>> with open('main_loop.tar', 'rb') as src: |
||
97 | ... parameters = load_parameters(src) |
||
98 | >>> sorted(parameters.keys()) # doctest: +ELLIPSIS |
||
99 | ['/mlp/linear_0.W', '/mlp/linear_0.b', '/mlp/linear_1.W', '/mlp/line...] |
||
100 | |||
101 | Loading parameters saved by :func:`dump` with :func:`load_parameters` |
||
102 | ensures that their hierarchical names are compatible with |
||
103 | :class:`~blocks.model.Model` and :class:`~blocks.select.Selector` classes. |
||
104 | |||
105 | TODO: Add information about :func:`add_to_dump`. |
||
106 | |||
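For now, a minimal sketch of adding an extra object (here, the training
log) to an existing dump; this example is not checked as a doctest:

>>> from blocks.serialization import add_to_dump
>>> with open('main_loop.tar', 'r+b') as dst:  # doctest: +SKIP
...     add_to_dump(main_loop.log, dst, 'log')
>>> tarfile.open('main_loop.tar', 'r').getnames()  # doctest: +SKIP
['_parameters', '_pkl', 'log']
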
107 | """ |
||
108 | import numpy |
||
109 | import os |
||
110 | import pickle |
||
111 | import shutil |
||
112 | import six |
||
113 | import tarfile |
||
114 | import tempfile |
||
115 | import warnings |
||
116 | import logging |
||
117 | from contextlib import closing |
||
118 | from pickle import HIGHEST_PROTOCOL |
||
119 | try: |
||
120 | from pickle import DEFAULT_PROTOCOL |
||
121 | from pickle import _Pickler |
||
122 | except ImportError: |
||
123 | DEFAULT_PROTOCOL = HIGHEST_PROTOCOL |
||
124 | from pickle import Pickler as _Pickler |
||
125 | |||
126 | from six.moves import cPickle |
||
127 | import theano |
||
128 | try: |
||
129 | from theano.sandbox.cuda import cuda_ndarray |
||
130 | except Exception: |
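    # Not only ImportError: importing the old CUDA backend can fail in
    # several ways when no GPU is configured, so any failure is treated
    # as "backend unavailable".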
    cuda_ndarray = None
try:
    import pygpu
except Exception:
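    # As above, any failure to import pygpu means it is unavailable.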
    pygpu = None
from blocks.config import config
from blocks.filter import get_brick
from blocks.utils import change_recursion_limit
from blocks.bricks.base import BRICK_DELIMITER


logger = logging.getLogger(__name__)

SERIALIZATION_BRICK_DELIMITER = '|'
MAIN_MODULE_WARNING = """WARNING: Main loop depends on the function `{}` in \
`__main__` namespace.

Because of limitations to pickling, this means that you will not be able to \
resume your model outside of a namespace containing this function. In other \
words, you can only call `continue_training` from within this script."""


def dump(object_, file_, parameters=None, use_cpickle=False,
         protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object, optionally saving its parameters separately.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to True will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    with closing(tarfile.TarFile(fileobj=file_, mode='w')) as tar_file:
        external_objects = {}

        def _save_parameters(f):
            renamer = _Renamer()
            named_parameters = {renamer(p): p for p in parameters}
            numpy.savez(f, **{n: p.get_value()
                              for n, p in named_parameters.items()})
            for name, p in named_parameters.items():
                array_ = p.container.storage[0]
                external_objects[id(array_)] = _mangle_parameter_name(p, name)
        if parameters:
            _taradd(_save_parameters, tar_file, '_parameters')
        if object_ is not None:
            save_object = _SaveObject(pickler, object_, external_objects,
                                      protocol, **kwargs)
            _taradd(save_object, tar_file, '_pkl')


def secure_dump(object_, path, dump_function=dump, **kwargs):
    r"""Robust serialization - does not corrupt your files when it fails.

    Parameters
    ----------
    object_ : object
        The object to be saved to the disk.
    path : str
        The destination for saving.
    dump_function : function
        The function that is used to perform the serialization. Must take
        an object and a file object as arguments. By default, :func:`dump`
        is used. An alternative would be :func:`pickle.dump`.
    \*\*kwargs
        Keyword arguments to be passed to `dump_function`.

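    Examples
    --------
    A sketch of typical usage, checkpointing a main loop atomically; this
    example is not run as a doctest and assumes `main_loop` exists:

    >>> secure_dump(main_loop, 'checkpoint.tar',  # doctest: +SKIP
    ...             parameters=main_loop.model.parameters)
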
223 | """ |
||
224 | try: |
||
225 | logger.debug("Dumping object to a temporary file") |
||
226 | with tempfile.NamedTemporaryFile(delete=False, |
||
227 | dir=config.temp_dir) as temp: |
||
228 | dump_function(object_, temp, **kwargs) |
||
229 | logger.debug("Moving the temporary file") |
||
230 | shutil.move(temp.name, path) |
||
231 | logger.debug("Dump finished") |
||
232 | except: |
||
233 | if "temp" in locals(): |
||
234 | os.remove(temp.name) |
||
235 | raise |
||
236 | |||
237 | |||
def load(file_, name='_pkl', use_cpickle=False, **kwargs):
    r"""Loads an object saved using the `dump` function.

    By default, this function loads the object saved by the `dump`
    function. If some objects have been added to the archive using the
    `add_to_dump` function, then you can load them by passing their name
    to the `name` parameter.

    Parameters
    ----------
    file_ : file
        The file that contains the object to load.
    name : str
        Name of the object to load. Default is `_pkl`, meaning that it is
        the original object which has been dumped that is loaded.
    use_cpickle : bool
        Use cPickle instead of pickle. Default: False.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Unpickler`.
        Used for e.g. specifying the encoding so as to load legacy Python
        pickles under Python 3.x.

    Returns
    -------
    The object saved in ``file_``.

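    Examples
    --------
    A sketch of loading an object that was added under a custom name with
    :func:`add_to_dump`; not run as a doctest:

    >>> with open('main_loop.tar', 'rb') as src:  # doctest: +SKIP
    ...     log = load(src, name='log')
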
264 | """ |
||
265 | file_.seek(0) # To be able to read several objects in one file |
||
266 | if use_cpickle: |
||
267 | unpickler = cPickle.Unpickler |
||
268 | else: |
||
269 | unpickler = pickle.Unpickler |
||
270 | with tarfile.open(fileobj=file_, mode='r') as tar_file: |
||
271 | p = unpickler( |
||
272 | tar_file.extractfile(tar_file.getmember(name)), |
||
273 | **kwargs |
||
274 | ) |
||
275 | if '_parameters' in tar_file.getnames(): |
||
276 | p.persistent_load = _PersistentLoad(tar_file) |
||
277 | return p.load() |
||
278 | |||
279 | |||
def load_parameters(file_):
    """Loads the parameter values saved by :func:`dump`.

    This function loads the parameters that have been saved separately by
    :func:`dump`, i.e. the ones given to its `parameters` argument.

    Parameters
    ----------
    file_ : file
        The source to load the parameters from.

    Returns
    -------
    A dictionary of (parameter name, numpy array) pairs.

    """
    with closing(_load_parameters_npzfile(file_)) as npz_file:
        return {name.replace(SERIALIZATION_BRICK_DELIMITER,
                             BRICK_DELIMITER): value
                for name, value in npz_file.items()}


def add_to_dump(object_, file_, name, parameters=None, use_cpickle=False,
                protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object to an existing tar archive.

    This function allows dumping more objects into an existing archive. If
    the object you want to dump possesses the same set of shared variables
    as the object already dumped, you can pass them to the `parameters`
    argument, which avoids serializing them a second time. However, this
    won't work if the shared variables you pass to the `parameters`
    argument are not already in the archive.

    Parameters
    ----------
    object_ : object
        The object to pickle.
    file_ : file
        The destination for saving, opened in read-write mode (`r+`).
    name : str
        The name of the object you are dumping. It will be used as a file
        name in the archive. '_pkl' and '_parameters' are reserved names
        and can't be used.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file. Must be a
        subset of the parameters already in the archive.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to True will disable the
        warning message if you try to pickle objects from the main module!
        Be sure that you don't have the warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

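    Examples
    --------
    A sketch of adding an object that shares parameters with the one
    already in the archive; not run as a doctest:

    >>> with open('main_loop.tar', 'r+b') as f:  # doctest: +SKIP
    ...     add_to_dump(main_loop.model, f, 'model',
    ...                 parameters=main_loop.model.parameters)
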
339 | """ |
||
340 | if name in ['_pkl', '_parameters']: |
||
341 | raise ValueError("_pkl and _parameters are reserved names and can't" |
||
342 | " be used as name for your object.") |
||
343 | |||
344 | external_parameters = {} |
||
345 | if parameters is not None: |
||
346 | renamer = _Renamer() |
||
347 | named_parameters = {renamer(p): p for p in parameters} |
||
348 | for n, p in named_parameters.items(): |
||
349 | array_ = p.container.storage[0] |
||
350 | external_parameters[id(array_)] = _mangle_parameter_name(p, n) |
||
351 | |||
352 | # Check that the parameters are the same that the ones in the archive. |
||
353 | file_.seek(0) # To be able to read what is in the tar file already. |
||
354 | with closing(tarfile.TarFile(fileobj=file_, mode='r')) as tar_file: |
||
355 | if '_parameters' not in tar_file.getnames(): |
||
356 | raise ValueError("There is no parameters in the archive, so" |
||
357 | " you can't use the argument parameters.") |
||
358 | else: |
||
359 | parameters = numpy.load( |
||
360 | tar_file.extractfile(tar_file.getmember('_parameters'))) |
||
361 | s1 = set(parameters.keys()) |
||
362 | s2 = [_unmangle_parameter_name(x)[2] for x in |
||
363 | external_parameters.values()] |
||
364 | if not s1.issuperset(s2): |
||
365 | raise ValueError('The set of parameters is different' |
||
366 | ' from the one in the archive.') |
||
367 | |||
368 | if use_cpickle: |
||
369 | pickler = cPickle.Pickler |
||
370 | else: |
||
371 | pickler = _PicklerWithWarning |
||
372 | file_.seek(0) # To be able to add new things in the tar file. |
||
373 | with closing(tarfile.TarFile(fileobj=file_, mode='a')) as tar_file: |
||
374 | save_object = _SaveObject(pickler, object_, external_parameters, |
||
375 | protocol, **kwargs) |
||
376 | _taradd(save_object, tar_file, name) |
||
377 | |||
378 | |||
def continue_training(path):
    """Continues training using a checkpoint.

    Parameters
    ----------
    path : str
        Path to checkpoint.

    Notes
    -----
    Python picklers can unpickle objects from the global namespace only if
    they are present in the namespace where unpickling happens. Often global
    functions are needed for mapping, filtering and other data stream
    operations. If the main loop uses global objects and this function
    fails with a message like
    ```
    AttributeError: 'module' object has no attribute '...'
    ```
    it means that you need to import these objects.

    Examples
    --------
    This function can be used in two ways: in the script where the main
    loop is defined, or in a different script. For the latter option, see
    the Notes section.

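    A sketch of resuming from a checkpoint file; not run as a doctest:

    >>> continue_training('checkpoint.tar')  # doctest: +SKIP
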
405 | """ |
||
406 | with change_recursion_limit(config.recursion_limit): |
||
407 | with open(path, "rb") as f: |
||
408 | main_loop = load(f) |
||
409 | main_loop.run() |
||
410 | |||
411 | |||
def dump_and_add_to_dump(object_, file_, parameters=None, to_add=None,
                         use_cpickle=False, protocol=DEFAULT_PROTOCOL,
                         **kwargs):
    r"""Calls both `dump` and `add_to_dump` to serialize several objects.

    This function is used to serialize several objects at the same time,
    using the persistent id mechanism. Its main advantage is that it can
    be used with `secure_dump`.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    to_add : dict of objects
        A {'name': object} dictionary of additional objects to save in
        the tar archive. Its keys will be used as names in the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to True will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

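    Examples
    --------
    A sketch of checkpointing a main loop together with its log in one
    atomic operation; not run as a doctest:

    >>> secure_dump(main_loop, 'checkpoint.tar',  # doctest: +SKIP
    ...             dump_function=dump_and_add_to_dump,
    ...             parameters=main_loop.model.parameters,
    ...             to_add={'log': main_loop.log})
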
446 | """ |
||
447 | dump(object_, file_, parameters=parameters, use_cpickle=use_cpickle, |
||
448 | protocol=protocol, **kwargs) |
||
449 | if to_add is not None: |
||
450 | for name, obj in six.iteritems(to_add): |
||
451 | add_to_dump(obj, file_, name, parameters=parameters, |
||
452 | use_cpickle=use_cpickle, protocol=protocol, **kwargs) |
||
453 | |||
454 | |||
class _PicklerWithWarning(_Pickler):
    """Pickler that adds a warning message.

    Adds a warning message if we try to save an object referenced in the
    main module.

    """
    dispatch = _Pickler.dispatch.copy()

    def save_global(self, obj, name=None, **kwargs):
        module = getattr(obj, '__module__', None)
        if module == '__main__':
            warnings.warn(
                MAIN_MODULE_WARNING.format(kwargs.get('name', obj.__name__))
            )
        _Pickler.save_global(self, obj, name=name, **kwargs)

    dispatch[six.types.FunctionType] = save_global
    if six.PY2:
        dispatch[six.types.ClassType] = save_global
        dispatch[six.types.BuiltinFunctionType] = save_global
        dispatch[six.types.TypeType] = save_global


class _SaveObject(object):
    r"""Saves an object using Persistent ID.

    Parameters
    ----------
    pickler : object
        The pickler to use.
    object_ : object
        The object to pickle.
    external_objects : dict of object
        The external objects to save using persistent id.
    protocol : int, optional
        The pickling protocol to use.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    def __init__(self, pickler, object_, external_objects, protocol,
                 **kwargs):
        self.pickler = pickler
        self.object_ = object_
        self.external_objects = external_objects
        self.protocol = protocol
        self.kwargs = kwargs

    def __call__(self, f):
        p = self.pickler(f, protocol=self.protocol, **self.kwargs)
        p.persistent_id = _PersistentID(self.external_objects)
        p.dump(self.object_)


class _Renamer(object):
    """Returns a new name for the given parameter.

    It maintains a set of names already used to avoid naming
    collisions. It also provides names for variables without
    names.

    Attributes
    ----------
    used_names : set
        The set of names already taken.
    default_name : str
        The name to use if a parameter doesn't have a name. Default:
        'parameter'.

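    Examples
    --------
    Naming collisions are resolved by appending an index:

    >>> from theano import shared
    >>> renamer = _Renamer()
    >>> renamer(shared(0.0, name='W'))
    'W'
    >>> renamer(shared(0.0, name='W'))
    'W_2'
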
524 | """ |
||
525 | def __init__(self): |
||
526 | self.used_names = set() |
||
527 | self.default_name = 'parameter' |
||
528 | |||
529 | def __call__(self, parameter): |
||
530 | # Standard Blocks parameter |
||
531 | if get_brick(parameter) is not None: |
||
532 | name = get_brick(parameter).get_hierarchical_name( |
||
533 | parameter, SERIALIZATION_BRICK_DELIMITER) |
||
534 | # Shared variables with tag.name |
||
535 | elif hasattr(parameter.tag, 'name'): |
||
536 | name = parameter.tag.name |
||
537 | # Standard shared variable |
||
538 | elif parameter.name is not None: |
||
539 | name = parameter.name |
||
540 | # Variables without names |
||
541 | else: |
||
542 | name = self.default_name |
||
543 | # Handle naming collisions |
||
544 | if name in self.used_names: |
||
545 | i = 2 |
||
546 | new_name = '_'.join([name, str(i)]) |
||
547 | while new_name in self.used_names: |
||
548 | i += 1 |
||
549 | new_name = '_'.join([name, str(i)]) |
||
550 | name = new_name |
||
551 | self.used_names.add(name) |
||
552 | return name |
||
553 | |||
554 | |||
def _recreate_numpy_ndarray(_, content):
    return numpy.array(content)


def _recreate_cuda_ndarray(_, content):
    return cuda_ndarray.cuda_ndarray.CudaNdarray(content)


def _recreate_pygpu_array(context_name, content):
    context = theano.gpuarray.get_context(context_name)
    return pygpu.gpuarray.array(content, context=context)

_ARRAY_TYPE_MAP = {numpy.ndarray: 'numpy_ndarray'}
_INVERSE_ARRAY_TYPE_MAP = {'numpy_ndarray': _recreate_numpy_ndarray}
if cuda_ndarray:
    _ARRAY_TYPE_MAP[cuda_ndarray.cuda_ndarray.CudaNdarray] = 'cuda_ndarray'
    _INVERSE_ARRAY_TYPE_MAP['cuda_ndarray'] = _recreate_cuda_ndarray
if pygpu:
    _ARRAY_TYPE_MAP[pygpu.gpuarray.GpuArray] = 'gpuarray'
    _INVERSE_ARRAY_TYPE_MAP['gpuarray'] = _recreate_pygpu_array


class _PersistentID(object):
    """Returns persistent identifiers for objects saved separately."""
    def __init__(self, external_objects):
        self.external_objects = external_objects

    def __call__(self, object_):
        return self.external_objects.get(id(object_))


class _PersistentLoad(object):
    """Loads objects saved using the PersistentID mechanism."""
    def __init__(self, tar_file):
        self.tar_file = tar_file
        if '_parameters' in tar_file.getnames():
            self.parameters = numpy.load(
                tar_file.extractfile(tar_file.getmember('_parameters')))
        self._cache = {}

    def __call__(self, id_):
        # As we empirically found out, this method can be called multiple
        # times with the same id_. That's why we need a cache here to
        # avoid creating the same object more than once.
        if id_ not in self._cache:
            components = _unmangle_parameter_name(id_)
            self._cache[id_] = components[0](
                components[1], self.parameters[components[2]])
        return self._cache[id_]


def _mangle_parameter_name(parameter, name):
    array_type = type(parameter.container.storage[0])
    context_name = (parameter.context_name
                    if pygpu and
                    isinstance(parameter, pygpu.gpuarray.GpuArray)
                    else None)
    if isinstance(context_name, str) and '.' in context_name:
        raise ValueError("context name must not contain dots")
    return '#1{}.{}.{}'.format(
        _ARRAY_TYPE_MAP[array_type], context_name, name)

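# A mangled persistent id looks like '#1<array_type>.<context_name>.<name>';
# for example, a numpy-backed parameter named '|mlp|linear_0.W' is stored
# under the id '#1numpy_ndarray.None.|mlp|linear_0.W'.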

def _unmangle_parameter_name(mangled_name):
    if not isinstance(mangled_name, str):
        # This fixes an issue with protocol 0 on Python 3 where
        # 'mangled_name' is a bytes object, for some reason.
        mangled_name = mangled_name.decode('utf8')
    if mangled_name.startswith('#1'):
        type_, context_name, name = mangled_name[2:].split('.', 2)
        if context_name == 'None':
            context_name = None
    elif mangled_name.startswith('#'):
        # Backward compatibility
        type_, name = mangled_name[1:].split('.', 1)
        context_name = None
    else:
        raise ValueError("Do not recognize the mangled parameter name")
    return _INVERSE_ARRAY_TYPE_MAP[type_], context_name, name


def _taradd(func, tar_file, name):
    """Adds elements dumped by the function `func` to a tar_file.

    This function first calls the function `func` and adds the file that
    `func` dumps to the archive `tar_file`, under the name `name`.

    Parameters
    ----------
    func : function
        The dumping function.
    tar_file : file
        The archive that we are filling.
    name : str
        The name of the dumped file in the archive.

    """
    with tempfile.NamedTemporaryFile('wb', delete=False) as temp_file:
        func(temp_file)
        temp_file.close()
        tar_file.add(temp_file.name, arcname=name)
    if os.path.isfile(temp_file.name):
        os.remove(temp_file.name)


def _load_parameters_npzfile(file_):
    """Loads parameters from a .npz file in a tar archive."""
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
        return numpy.load(
            tar_file.extractfile(tar_file.getmember('_parameters')))