add_to_dump()   F
last analyzed

Complexity
    Conditions: 11

Size
    Total Lines: 75

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric   Value
------   -----
cc       11
dl       0
loc      75
rs       3.375
c        0
b        0
f        0

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. Moreover, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign that the commented part should be extracted into a new method, with the comment serving as a starting point for naming it.

Commonly applied refactorings include Extract Method.
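As a hypothetical illustration (the class and method names below are invented, not taken from the code under review), Extract Method turns a commented block into a named helper:

```python
class ReportBuilder:
    def __init__(self, entries):
        self.entries = entries

    def build(self):
        # Before the refactoring, this filtering logic lived inline behind
        # a comment ("keep only non-empty entries"); the comment became
        # the name of the extracted method.
        valid = self._keep_non_empty(self.entries)
        return "\n".join(valid)

    def _keep_non_empty(self, entries):
        """Extracted method: the old comment now names the behavior."""
        return [e.strip() for e in entries if e.strip()]


print(ReportBuilder(["a", " ", "b "]).build())  # prints "a\nb" as two lines
```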

Complexity

Complex methods like add_to_dump() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
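A minimal sketch of Extract Class, again with invented names: fields that share a common prefix are moved into their own cohesive component, and the original class delegates to it:

```python
class ParameterStore:
    """Extracted class: owns everything that shared the old prefix."""

    def __init__(self):
        self.values = {}

    def add(self, name, value):
        self.values[name] = value

    def names(self):
        return sorted(self.values)


class Dumper:
    """The original class now delegates parameter bookkeeping."""

    def __init__(self):
        self.parameters = ParameterStore()

    def dump(self, name, value):
        self.parameters.add(name, value)
        return self.parameters.names()
```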

"""Blocks native serialization - tar files with pickles and numpy arrays.

This module provides :func:`load` and :func:`dump` functions that can serve
as drop-in replacements for the respective functions from the standard
:mod:`pickle` module. The main differences between them and the standard
ones are:

    - The dump is physically a tarball, in which the pickle is stored
      as the '_pkl' file.

    - A special file '_parameters' in the tarball can contain the data
      of a selected set of Theano shared variables. This data is
      referenced from '_pkl' using the persistent id mechanism, which means
      that no duplication takes place. The goal here is to save the values
      of the parameters (this is what these shared variables are in most
      cases) in the most robust way possible. The actual format for the
      '_parameters' file is the one used by :func:`numpy.savez`, i.e. a zip
      file of numpy arrays.

    - More objects can be dumped in the archive using the `add_to_dump`
      function. If an object shares parameters with one that has already
      been dumped, those parameters are not serialized again thanks to the
      persistent id mechanism.

    - :func:`dump` strives to catch situations when the user tries
      to pickle a function or a class not defined in the global namespace
      and gives a meaningful warning.

In brief, this module proposes a dumping mechanism which allows for
greater robustness and persistence than standard pickling.

Examples
--------
Consider a standard main loop (without an algorithm and a data stream
for brevity)

>>> from theano import tensor
>>> from blocks.main_loop import MainLoop
>>> from blocks.bricks import MLP, Tanh, Softmax
>>> from blocks.model import Model
>>> mlp = MLP([Tanh(), None], [784, 10, 10])
>>> x = tensor.matrix('features')
>>> y = tensor.lmatrix('targets')
>>> cost = Softmax().categorical_cross_entropy(
...            y.flatten(), mlp.apply(tensor.flatten(x, outdim=2)))
>>> main_loop = MainLoop(None, None, model=Model(cost))

Let's see how the main loop is dumped by :func:`dump`

>>> from blocks.serialization import dump, load
>>> import tarfile
>>> with open('main_loop.tar', 'wb') as dst:
...     dump(main_loop, dst)
>>> tarball = tarfile.open('main_loop.tar', 'r')
>>> tarball # doctest: +ELLIPSIS
<tarfile.TarFile object at ...>
>>> tarball.getnames()
['_pkl']
>>> tarball.close()

As promised, the dump is a tarball. Since we did not ask for any additional
magic, it just contains the pickled main loop in the '_pkl' file.

Let's do something more interesting:

>>> with open('main_loop.tar', 'wb') as dst:
...     dump(main_loop, dst,
...          parameters=main_loop.model.parameters)
>>> tarball = tarfile.open('main_loop.tar', 'r')
>>> tarball.getnames()
['_parameters', '_pkl']

As requested by specifying the `parameters` argument, the parameters were
saved in a zip file.

>>> import numpy
>>> ps = numpy.load(tarball.extractfile(tarball.getmember('_parameters')))
>>> sorted(ps.keys()) # doctest: +ELLIPSIS
['|mlp|linear_0.W', '|mlp|linear_0.b', '|mlp|linear_1.W', '|mlp|lin...]
>>> ps.close()

The names for parameters are chosen intelligently to reflect their
position in the brick hierarchy, if they belong to bricks, and by
simply using the `.name` attribute, if they do not.

The loading of the main loop as a whole still works:

>>> with open('main_loop.tar', 'rb') as src:
...     main_loop_loaded = load(src)
>>> main_loop_loaded # doctest: +ELLIPSIS
<blocks.main_loop.MainLoop object at ...>

Additionally, this module provides a convenience routine
:func:`load_parameters`:

>>> with open('main_loop.tar', 'rb') as src:
...     parameters = load_parameters(src)
>>> sorted(parameters.keys()) # doctest: +ELLIPSIS
['/mlp/linear_0.W', '/mlp/linear_0.b', '/mlp/linear_1.W', '/mlp/line...]

Loading parameters saved by :func:`dump` with :func:`load_parameters`
ensures that their hierarchical names are compatible with the
:class:`~blocks.model.Model` and :class:`~blocks.select.Selector` classes.

TODO: Add information about :func:`add_to_dump`.

"""
import numpy
import os
import pickle
import shutil
import six
import tarfile
import tempfile
import warnings
import logging
from contextlib import closing
from pickle import HIGHEST_PROTOCOL
try:
    from pickle import DEFAULT_PROTOCOL
    from pickle import _Pickler
except ImportError:
    DEFAULT_PROTOCOL = HIGHEST_PROTOCOL
    from pickle import Pickler as _Pickler

from six.moves import cPickle
import theano
try:
    from theano.sandbox.cuda import cuda_ndarray
except Exception:  # broad on purpose: the CUDA backend may fail to import
    cuda_ndarray = None
try:
    import pygpu
except Exception:  # broad on purpose: pygpu may fail to import
    pygpu = None

from blocks.config import config
from blocks.filter import get_brick
from blocks.utils import change_recursion_limit
from blocks.bricks.base import BRICK_DELIMITER


logger = logging.getLogger(__name__)

SERIALIZATION_BRICK_DELIMITER = '|'
MAIN_MODULE_WARNING = """WARNING: Main loop depends on the function `{}` in \
`__main__` namespace.

Because of limitations to pickling, this means that you will not be able to \
resume your model outside of a namespace containing this function. In other \
words, you can only call `continue_training` from within this script."""


def dump(object_, file_, parameters=None, use_cpickle=False,
         protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object, optionally saving its parameters separately.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    with closing(tarfile.TarFile(fileobj=file_, mode='w')) as tar_file:
        external_objects = {}

        def _save_parameters(f):
            renamer = _Renamer()
            named_parameters = {renamer(p): p for p in parameters}
            numpy.savez(f, **{n: p.get_value()
                              for n, p in named_parameters.items()})
            for name, p in named_parameters.items():
                array_ = p.container.storage[0]
                external_objects[id(array_)] = _mangle_parameter_name(p, name)
        if parameters:
            _taradd(_save_parameters, tar_file, '_parameters')
        if object_ is not None:
            save_object = _SaveObject(pickler, object_, external_objects,
                                      protocol, **kwargs)
            _taradd(save_object, tar_file, '_pkl')


def secure_dump(object_, path, dump_function=dump, **kwargs):
    r"""Robust serialization - does not corrupt your files when it fails.

    Parameters
    ----------
    object_ : object
        The object to be saved to the disk.
    path : str
        The destination for saving.
    dump_function : function
        The function that is used to perform the serialization. Must take
        an object and a file object as arguments. By default, :func:`dump`
        is used. An alternative would be :func:`pickle.dump`.
    \*\*kwargs
        Keyword arguments to be passed to `dump_function`.

    """
    try:
        logger.debug("Dumping object to a temporary file")
        with tempfile.NamedTemporaryFile(delete=False,
                                         dir=config.temp_dir) as temp:
            dump_function(object_, temp, **kwargs)
        logger.debug("Moving the temporary file")
        shutil.move(temp.name, path)
        logger.debug("Dump finished")
    except:
        # Clean up the temporary file on any failure, then re-raise.
        if "temp" in locals():
            os.remove(temp.name)
        raise


def load(file_, name='_pkl', use_cpickle=False, **kwargs):
    r"""Loads an object saved using the `dump` function.

    By default, this function loads the object saved by the `dump`
    function. If some objects have been added to the archive using the
    `add_to_dump` function, then you can load them by passing their name
    to the `name` parameter.

    Parameters
    ----------
    file_ : file
        The file that contains the object to load.
    name : str
        Name of the object to load. Default is `_pkl`, meaning that the
        original object that was dumped is loaded.
    use_cpickle : bool
        Use cPickle instead of pickle. Default: False.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Unpickler`.
        Used for e.g. specifying the encoding so as to load legacy Python
        pickles under Python 3.x.

    Returns
    -------
    The object saved in ``file_``.

    """
    file_.seek(0)  # To be able to read several objects in one file
    if use_cpickle:
        unpickler = cPickle.Unpickler
    else:
        unpickler = pickle.Unpickler
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
        p = unpickler(
            tar_file.extractfile(tar_file.getmember(name)),
            **kwargs
        )
        if '_parameters' in tar_file.getnames():
            p.persistent_load = _PersistentLoad(tar_file)
        return p.load()


def load_parameters(file_):
    """Loads the parameter values saved by :func:`dump`.

    This function loads the parameters that have been saved separately by
    :func:`dump`, i.e. the ones given to its `parameters` argument.

    Parameters
    ----------
    file_ : file
        The source to load the parameters from.

    Returns
    -------
    A dictionary of (parameter name, numpy array) pairs.

    """
    with closing(_load_parameters_npzfile(file_)) as npz_file:
        return {name.replace(SERIALIZATION_BRICK_DELIMITER,
                             BRICK_DELIMITER): value
                for name, value in npz_file.items()}


def add_to_dump(object_, file_, name, parameters=None, use_cpickle=False,
                protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object to an existing tar archive.

    This function allows dumping more objects to an existing archive. If
    the object you want to dump possesses the same set of shared variables
    as the object already dumped, you can pass them to the `parameters`
    argument, which avoids serializing them a second time. However, it
    won't work if the shared variables you pass to the `parameters`
    argument are not already in the archive.

    Parameters
    ----------
    object_ : object
        The object to pickle.
    file_ : file
        The destination for saving, opened in read-write mode (`r+`).
    name : str
        The name of the object you are dumping. It will be used as a file
        name in the archive. '_pkl' and '_parameters' are reserved names
        and can't be used.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file. Must be a
        subset of the parameters already in the archive.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module!
        Be sure that you don't have the warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    if name in ['_pkl', '_parameters']:
        raise ValueError("_pkl and _parameters are reserved names and can't"
                         " be used as names for your objects.")

    external_parameters = {}
    if parameters is not None:
        renamer = _Renamer()
        named_parameters = {renamer(p): p for p in parameters}
        for n, p in named_parameters.items():
            array_ = p.container.storage[0]
            external_parameters[id(array_)] = _mangle_parameter_name(p, n)

        # Check that the parameters are the same as the ones in the archive.
        file_.seek(0)  # To be able to read what is in the tar file already.
        with closing(tarfile.TarFile(fileobj=file_, mode='r')) as tar_file:
            if '_parameters' not in tar_file.getnames():
                raise ValueError("There are no parameters in the archive, so"
                                 " you can't use the argument parameters.")
            else:
                parameters = numpy.load(
                    tar_file.extractfile(tar_file.getmember('_parameters')))
                s1 = set(parameters.keys())
                s2 = [_unmangle_parameter_name(x)[2] for x in
                      external_parameters.values()]
                if not s1.issuperset(s2):
                    raise ValueError('The set of parameters is different'
                                     ' from the one in the archive.')

    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    file_.seek(0)  # To be able to add new things in the tar file.
    with closing(tarfile.TarFile(fileobj=file_, mode='a')) as tar_file:
        save_object = _SaveObject(pickler, object_, external_parameters,
                                  protocol, **kwargs)
        _taradd(save_object, tar_file, name)


def continue_training(path):
    """Continues training using a checkpoint.

    Parameters
    ----------
    path : str
        Path to the checkpoint.

    Notes
    -----
    Python picklers can unpickle objects from the global namespace only if
    they are present in the namespace where unpickling happens. Often
    global functions are needed for mapping, filtering and other data
    stream operations. If the main loop uses global objects and this
    function fails with a message like
    ```
    AttributeError: 'module' object has no attribute '...'
    ```
    it means that you need to import these objects.

    Examples
    --------
    This function can be used in two ways: in the script where the main
    loop is defined, or in a different script. For the latter option, see
    the Notes section.

    """
    with change_recursion_limit(config.recursion_limit):
        with open(path, "rb") as f:
            main_loop = load(f)
    main_loop.run()


def dump_and_add_to_dump(object_, file_, parameters=None, to_add=None,
                         use_cpickle=False, protocol=DEFAULT_PROTOCOL,
                         **kwargs):
    r"""Calls both `dump` and `add_to_dump` to serialize several objects.

    This function is used to serialize several objects at the same time,
    using persistent IDs. Its main advantage is that it can be used with
    `secure_dump`.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    to_add : dict of objects
        A {'name': object} dictionary of additional objects to save in
        the tar archive. Its keys will be used as names in the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    dump(object_, file_, parameters=parameters, use_cpickle=use_cpickle,
         protocol=protocol, **kwargs)
    if to_add is not None:
        for name, obj in six.iteritems(to_add):
            add_to_dump(obj, file_, name, parameters=parameters,
                        use_cpickle=use_cpickle, protocol=protocol, **kwargs)


class _PicklerWithWarning(_Pickler):
    """Pickler that adds a warning message.

    Adds a warning message if we try to save an object referenced in the
    main module.

    """
    dispatch = _Pickler.dispatch.copy()

    def save_global(self, obj, name=None, **kwargs):
        module = getattr(obj, '__module__', None)
        if module == '__main__':
            warnings.warn(
                MAIN_MODULE_WARNING.format(kwargs.get('name', obj.__name__))
            )
        _Pickler.save_global(self, obj, name=name, **kwargs)

    dispatch[six.types.FunctionType] = save_global
    if six.PY2:
        # types.ClassType and types.TypeType only exist on Python 2,
        # which is why static analysis under Python 3 flags them.
        dispatch[six.types.ClassType] = save_global
        dispatch[six.types.BuiltinFunctionType] = save_global
        dispatch[six.types.TypeType] = save_global


class _SaveObject(object):
    r"""Saves an object using persistent ID.

    Parameters
    ----------
    pickler : object
        The pickler to use.
    object_ : object
        The object to pickle.
    external_objects : dict of object
        The external objects to save using persistent id.
    protocol : int, optional
        The pickling protocol to use.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    def __init__(self, pickler, object_, external_objects, protocol, **kwargs):
        self.pickler = pickler
        self.object_ = object_
        self.external_objects = external_objects
        self.protocol = protocol
        self.kwargs = kwargs

    def __call__(self, f):
        p = self.pickler(f, protocol=self.protocol, **self.kwargs)
        p.persistent_id = _PersistentID(self.external_objects)
        p.dump(self.object_)


class _Renamer(object):
    """Returns a new name for the given parameter.

    It maintains a set of names already used to avoid naming
    collisions. It also provides names for variables without
    names.

    Attributes
    ----------
    used_names : set
        The set of names already taken.
    default_name : str
        The name to use if a parameter doesn't have a name. Default:
        'parameter'.

    """
    def __init__(self):
        self.used_names = set()
        self.default_name = 'parameter'

    def __call__(self, parameter):
        # Standard Blocks parameter
        if get_brick(parameter) is not None:
            name = get_brick(parameter).get_hierarchical_name(
                parameter, SERIALIZATION_BRICK_DELIMITER)
        # Shared variables with tag.name
        elif hasattr(parameter.tag, 'name'):
            name = parameter.tag.name
        # Standard shared variable
        elif parameter.name is not None:
            name = parameter.name
        # Variables without names
        else:
            name = self.default_name
        # Handle naming collisions
        if name in self.used_names:
            i = 2
            new_name = '_'.join([name, str(i)])
            while new_name in self.used_names:
                i += 1
                new_name = '_'.join([name, str(i)])
            name = new_name
        self.used_names.add(name)
        return name


def _recreate_numpy_ndarray(_, content):
    return numpy.array(content)


def _recreate_cuda_ndarray(_, content):
    return cuda_ndarray.cuda_ndarray.CudaNdarray(content)


def _recreate_pygpu_array(context_name, content):
    context = theano.gpuarray.get_context(context_name)
    return pygpu.gpuarray.array(content, context=context)

_ARRAY_TYPE_MAP = {numpy.ndarray: 'numpy_ndarray'}
_INVERSE_ARRAY_TYPE_MAP = {'numpy_ndarray': _recreate_numpy_ndarray}
if cuda_ndarray:
    _ARRAY_TYPE_MAP[cuda_ndarray.cuda_ndarray.CudaNdarray] = 'cuda_ndarray'
    _INVERSE_ARRAY_TYPE_MAP['cuda_ndarray'] = _recreate_cuda_ndarray
if pygpu:
    _ARRAY_TYPE_MAP[pygpu.gpuarray.GpuArray] = 'gpuarray'
    _INVERSE_ARRAY_TYPE_MAP['gpuarray'] = _recreate_pygpu_array


class _PersistentID(object):
    """Returns persistent identifiers for objects saved separately."""
    def __init__(self, external_objects):
        self.external_objects = external_objects

    def __call__(self, object_):
        return self.external_objects.get(id(object_))


class _PersistentLoad(object):
    """Loads objects saved using the persistent ID mechanism."""
    def __init__(self, tar_file):
        self.tar_file = tar_file
        if '_parameters' in tar_file.getnames():
            self.parameters = numpy.load(
                tar_file.extractfile(tar_file.getmember('_parameters')))
        self._cache = {}

    def __call__(self, id_):
        # As we empirically found out, this method can be called multiple
        # times with the same id_. That's why we need a cache here to
        # avoid creating the same object more than once.
        if id_ not in self._cache:
            components = _unmangle_parameter_name(id_)
            self._cache[id_] = components[0](
                components[1], self.parameters[components[2]])
        return self._cache[id_]


def _mangle_parameter_name(parameter, name):
    array_type = type(parameter.container.storage[0])
    context_name = (parameter.context_name
                    if pygpu and
                    isinstance(parameter, pygpu.gpuarray.GpuArray)
                    else None)
    if isinstance(context_name, str) and '.' in context_name:
        raise ValueError("context name must not contain dots")
    return '#1{}.{}.{}'.format(
        _ARRAY_TYPE_MAP[array_type], context_name, name)


def _unmangle_parameter_name(mangled_name):
    if not isinstance(mangled_name, str):
        # This fixes an issue with protocol 0 on Python 3 where
        # 'mangled_name' is a bytes object, for some reason.
        mangled_name = mangled_name.decode('utf8')
    if mangled_name.startswith('#1'):
        type_, context_name, name = mangled_name[2:].split('.', 2)
        if context_name == 'None':
            context_name = None
    elif mangled_name.startswith('#'):
        # Backward compatibility
        type_, name = mangled_name[1:].split('.', 1)
        context_name = None
    else:
        raise ValueError("Do not recognize the mangled parameter name")
    return _INVERSE_ARRAY_TYPE_MAP[type_], context_name, name


def _taradd(func, tar_file, name):
    """Adds elements dumped by the function `func` to a tar_file.

    This function first calls `func` to dump to a temporary file, then
    adds that file to the archive `tar_file` under the name `name`.

    Parameters
    ----------
    func : function
        The dumping function.
    tar_file : file
        The archive that we are filling.
    name : str
        The name of the dumped file in the archive.

    """
    with tempfile.NamedTemporaryFile('wb', delete=False) as temp_file:
        func(temp_file)
        temp_file.close()
        tar_file.add(temp_file.name, arcname=name)
    if os.path.isfile(temp_file.name):
        os.remove(temp_file.name)


def _load_parameters_npzfile(file_):
    """Loads parameters from a .npz file in a tar archive."""
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
        return numpy.load(
            tar_file.extractfile(tar_file.getmember('_parameters')))
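Although the review flags `add_to_dump()` as complex, the persistent-id idea underneath it is small. The following is an independent, hypothetical sketch using plain `pickle` and `numpy` (not the module's actual tar/npz wire format): arrays known to an external store are pickled as tiny string ids, and resolved back to the shared object on load.

```python
import io
import pickle

import numpy

# External store: the array is kept once, outside the pickle stream.
store = {'w': numpy.arange(3)}
ids_by_obj = {id(v): k for k, v in store.items()}


class RefPickler(pickle.Pickler):
    def persistent_id(self, obj):
        # Return a string id for known arrays; None means pickle normally.
        return ids_by_obj.get(id(obj))


class RefUnpickler(pickle.Unpickler):
    def persistent_load(self, pid):
        # Resolve the id back to the externally stored array.
        return store[pid]


buf = io.BytesIO()
# Two references to the same array are pickled as two ids, not two copies.
RefPickler(buf).dump({'first': store['w'], 'second': store['w']})
buf.seek(0)
loaded = RefUnpickler(buf).load()
print(loaded['first'] is loaded['second'])  # prints True: shared, not duplicated
```

This mirrors how `_PersistentID` maps array ids to mangled parameter names during `dump`, and how `_PersistentLoad` (with its cache) resolves those names back to a single array per id on `load`.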
665