Completed
Push — master ( 41dd4a...dba974 )
by Dmitry
03:30
created

_unmangle_parameter_name()   B

Complexity

Conditions 5

Size

Total Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 5
c 1
b 0
f 0
dl 0
loc 16
rs 8.5454
1
"""Blocks native serialization - tar files with pickles and numpy arrays.
2
3
This module provides :func:`load` and :func:`dump` functions that can serve
4
as drop-in replacement for the respective functions from the standard
5
:mod:`pickle` module. The main differences between them and the standard
6
ones are:
7
8
    - The dump is physically a tarball, in which the pickle is stored
9
      as '_pkl' file.
10
11
    - A special file '_parameters' in the tarball can contain the data
12
      of a selected set of Theano shared variables. This data is
13
      referenced from `_pkl` using persistent id mechanism, which means
14
      that no duplication takes place. The goal here is to save the values
15
      of the parameters (this is what these shared variables are in most
16
      cases) in the most robust way possible. The actual format for
17
      '_parameters' file is the one used by :func:`numpy.savez`, i.e. a zip
18
      file of numpy arrays.
19
20
    - More objects can be dumped in the archive using the `add_to_dump`
21
      function. If the object has the same parameters as the one already
22
      dumped, then you can avoid to dump those parameters thank to the
23
      persistent id mechanism.
24
25
    - The :func:`dump` strives to catch situations when the user tries
26
      to pickle a function or a class not defined in the global namespace
27
      and give a meaningful warning.
28
29
If briefly, this module proposes a dumping mechanism which allows for
30
greater robustness and persistence than standard pickling.
31
32
Examples
33
--------
34
Consider a standard main loop (without an algorithm and a data stream
35
for brevity)
36
37
>>> from theano import tensor
38
>>> from blocks.main_loop import MainLoop
39
>>> from blocks.bricks import MLP, Tanh, Softmax
40
>>> from blocks.model import Model
41
>>> mlp = MLP([Tanh(), None], [784, 10, 10])
42
>>> x = tensor.matrix('features')
43
>>> y = tensor.lmatrix('targets')
44
>>> cost = Softmax().categorical_cross_entropy(
45
...            y.flatten(), mlp.apply(tensor.flatten(x, outdim=2)))
46
>>> main_loop = MainLoop(None, None, model=Model(cost))
47
48
Let's see how the main loop is dumped by :func:`dump`
49
50
>>> from blocks.serialization import dump, load
51
>>> import tarfile
52
>>> with open('main_loop.tar', 'wb') as dst:
53
...     dump(main_loop, dst)
54
>>> tarball = tarfile.open('main_loop.tar', 'r')
55
>>> tarball # doctest: +ELLIPSIS
56
<tarfile.TarFile object at ...>
57
>>> tarball.getnames()
58
['_pkl']
59
>>> tarball.close()
60
61
As promised, the dump is a tarball. Since we did not ask for any additional
62
magic, it just contains the pickled main loop in '_pkl' file.
63
64
Let's do something more interesting:
65
66
>>> with open('main_loop.tar', 'wb') as dst:
67
...     dump(main_loop, dst,
68
...          parameters=main_loop.model.parameters)
69
>>> tarball = tarfile.open('main_loop.tar', 'r')
70
>>> tarball.getnames()
71
['_parameters', '_pkl']
72
73
As requested by specifying the `_parameters` argument, the parameters were
74
saved in a zip file.
75
76
>>> import numpy
77
>>> ps = numpy.load(tarball.extractfile(tarball.getmember('_parameters')))
78
>>> sorted(ps.keys()) # doctest: +ELLIPSIS
79
['|mlp|linear_0.W', '|mlp|linear_0.b', '|mlp|linear_1.W', '|mlp|lin...]
80
>>> ps.close()
81
82
The names for parameters are chosen intelligently to reflect their
83
position in the brick hierarchy, if they belong to bricks, and by
84
simply using the `.name` attribute, if they do not.
85
86
The loading of the main loop as a whole still works:
87
88
>>> with open('main_loop.tar', 'rb') as src:
89
...     main_loop_loaded = load(src)
90
>>> main_loop_loaded # doctest: +ELLIPSIS
91
<blocks.main_loop.MainLoop object at ...>
92
93
Additionally, this module provides convenience routine
94
:func:`load_parameters`:
95
96
>>> with open('main_loop.tar', 'rb') as src:
97
...     parameters = load_parameters(src)
98
>>> sorted(parameters.keys()) # doctest: +ELLIPSIS
99
['/mlp/linear_0.W', '/mlp/linear_0.b', '/mlp/linear_1.W', '/mlp/line...]
100
101
Loading parameters saved by :func:`dump` with :func:`load_parameters`
102
ensures that their hierarchical names are compatible with
103
:class:`~blocks.model.Model` and :class:`~blocks.select.Selector` classes.
104
105
TODO: Add information about :func:`add_to_dump`.
106
107
"""
108
import numpy
109
import os
110
import pickle
111
import shutil
112
import six
113
import tarfile
114
import tempfile
115
import warnings
116
import logging
117
from contextlib import closing
118
from pickle import HIGHEST_PROTOCOL
119
try:
120
    from pickle import DEFAULT_PROTOCOL
121
    from pickle import _Pickler
122
except ImportError:
123
    DEFAULT_PROTOCOL = HIGHEST_PROTOCOL
124
    from pickle import Pickler as _Pickler
125
126
from six.moves import cPickle
127
import theano
128
try:
129
    from theano.sandbox.cuda import cuda_ndarray
130
except ImportError:
131
    cuda_ndarray = None
132
try:
133
    import pygpu
134
except:
0 ignored issues
show
Coding Style Best Practice introduced by
General except handlers without types should be used sparingly.

Typically, you would use general except handlers when you intend to specifically handle all types of errors, f.e. when logging. Otherwise, such general error handlers can mask errors in your application that you want to know of.

Loading history...
135
    pygpu = None
0 ignored issues
show
Coding Style Naming introduced by
The name pygpu does not conform to the constant naming conventions ((([A-Z_][A-Z0-9_]*)|(__.*__))$).

This check looks for invalid names for a range of different identifiers.

You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.

If your project includes a Pylint configuration file, the settings contained in that file take precedence.

To find out more about Pylint, please refer to their site.

Loading history...
136
from blocks.config import config
137
from blocks.filter import get_brick
138
from blocks.utils import change_recursion_limit
139
140
141
logger = logging.getLogger(__name__)
142
143
BRICK_DELIMITER = '|'
144
MAIN_MODULE_WARNING = """WARNING: Main loop depends on the function `{}` in \
145
`__main__` namespace.
146
147
Because of limitations to pickling, this means that you will not be able to \
148
resume your model outside of a namespace containing this function. In other \
149
words, you can only call `continue_training` from within this script."""
150
151
152
def dump(object_, file_, parameters=None, use_cpickle=False,
         protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object, optionally saving its parameters separately.

    The result is a tar archive written to `file_` containing a '_pkl'
    member (the pickle of `object_`) and, when `parameters` is given, a
    '_parameters' member (an npz file of the parameter arrays).

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved and no '_pkl' member is
        written.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        # This pickler warns when objects from __main__ are pickled.
        pickler = _PicklerWithWarning
    with closing(tarfile.TarFile(fileobj=file_, mode='w')) as tar_file:
        # Maps id(array) -> mangled persistent-id string.  Filled by
        # _save_parameters and consulted by the pickler below, so that
        # parameter data referenced from '_pkl' is not duplicated.
        external_objects = {}

        def _save_parameters(f):
            # Write the parameter arrays to `f` as an npz archive under
            # unique, hierarchy-reflecting names.
            renamer = _Renamer()
            named_parameters = {renamer(p): p for p in parameters}
            numpy.savez(f, **{n: p.get_value()
                              for n, p in named_parameters.items()})
            for name, p in named_parameters.items():
                array_ = p.container.storage[0]
                external_objects[id(array_)] = _mangle_parameter_name(p, name)
        if parameters:
            # Must run before the '_pkl' dump so external_objects is
            # populated when the pickler asks for persistent ids.
            _taradd(_save_parameters, tar_file, '_parameters')
        if object_ is not None:
            save_object = _SaveObject(pickler, object_, external_objects,
                                      protocol, **kwargs)
            _taradd(save_object, tar_file, '_pkl')
204
205
206
def secure_dump(object_, path, dump_function=dump, **kwargs):
    r"""Robust serialization - does not corrupt your files when failed.

    The object is first dumped to a temporary file; only if the dump
    succeeds is the temporary file moved over `path`.  On failure the
    temporary file is removed and `path` is left untouched.

    Parameters
    ----------
    object_ : object
        The object to be saved to the disk.
    path : str
        The destination for saving.
    dump_function : function
        The function that is used to perform the serialization. Must take
        an object and file object as arguments. By default, :func:`dump` is
        used. An alternative would be :func:`pickle.dump`.
    \*\*kwargs
        Keyword arguments to be passed to `dump_function`.

    """
    try:
        logger.debug("Dumping object to a temporary file")
        with tempfile.NamedTemporaryFile(delete=False,
                                         dir=config.temp_dir) as temp:
            dump_function(object_, temp, **kwargs)
        logger.debug("Moving the temporary file")
        shutil.move(temp.name, path)
        logger.debug("Dump finished")
    # BaseException (not a bare except) so that even KeyboardInterrupt /
    # SystemExit trigger the cleanup; the exception is always re-raised.
    except BaseException:
        # `temp` is only bound if NamedTemporaryFile itself succeeded.
        if "temp" in locals():
            os.remove(temp.name)
        raise
235
236
237
def load(file_, name='_pkl', use_cpickle=False):
    """Loads an object saved using the `dump` function.

    By default the object stored under the '_pkl' member is loaded.
    Objects added to the archive with `add_to_dump` can be retrieved by
    passing their name through the `name` argument.

    Parameters
    ----------
    file_ : file
        The file that contains the object to load.
    name : str
        Name of the object to load. Default is `_pkl`, i.e. the object
        originally passed to :func:`dump`.
    use_cpickle : bool
        Use cPickle instead of pickle. Default: False.

    Returns
    -------
    The object saved in `file_`.

    """
    # Rewind so that several objects can be read from the same file.
    file_.seek(0)
    unpickler = cPickle.Unpickler if use_cpickle else pickle.Unpickler
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
        member = tar_file.getmember(name)
        p = unpickler(tar_file.extractfile(member))
        if '_parameters' in tar_file.getnames():
            # Resolve persistent ids against the saved parameter arrays.
            p.persistent_load = _PersistentLoad(tar_file)
        return p.load()
271
272
273
def load_parameters(file_):
    """Loads the parameter values saved by :func:`dump`.

    This function loads the parameters that were saved separately by
    :func:`dump`, i.e. the ones given to its `parameters` argument.

    Parameters
    ----------
    file_ : file
        The source to load the parameters from.

    Returns
    -------
    A dictionary of (parameter name, numpy array) pairs.

    """
    with closing(_load_parameters_npzfile(file_)) as npz_file:
        parameters = {}
        for name, value in npz_file.items():
            # Saved names use BRICK_DELIMITER as the separator; Model
            # and Selector expect '/'-separated hierarchical names.
            parameters[name.replace(BRICK_DELIMITER, '/')] = value
        return parameters
292
293
294
def add_to_dump(object_, file_, name, parameters=None, use_cpickle=False,
                protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object to an existing tar archive.

    This function allows to dump more objects to an existing archive. If
    the object you want to dump possesses the same set of shared variables
    as the object already dumped, you can pass them to the `parameters`
    argument, which will avoid them to be serialized a second time.
    However, it won't work if the shared variable you pass to the
    `parameters` argument are not already in the archive.

    Parameters
    ----------
    object_ : object
        The object to pickle.
    file_ : file
        The destination for saving, opened in read-write mode (`r+`).
    name : str
        The name of the object you are dumping. It will be used as a file
        name in the archive. '_pkl' and '_parameters' are reserved names
        and can't be used.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file. Must be a
        subset of the parameters already in the archive.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module!
        Be sure that you don't have the warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    Raises
    ------
    ValueError
        If `name` is one of the reserved names, if `parameters` is given
        but the archive has no '_parameters' member, or if `parameters`
        is not a subset of the parameters already in the archive.

    """
    if name in ['_pkl', '_parameters']:
        raise ValueError("_pkl and _parameters are reserved names and can't"
                         " be used as name for your object.")

    external_parameters = {}
    if parameters is not None:
        renamer = _Renamer()
        named_parameters = {renamer(p): p for p in parameters}
        for n, p in named_parameters.items():
            array_ = p.container.storage[0]
            external_parameters[id(array_)] = _mangle_parameter_name(p, n)

        # Check that the parameters are the same that the ones in the archive.
        file_.seek(0)  # To be able to read what is in the tar file already.
        with closing(tarfile.TarFile(fileobj=file_, mode='r')) as tar_file:
            if '_parameters' not in tar_file.getnames():
                raise ValueError("There is no parameters in the archive, so"
                                 " you can't use the argument parameters.")
            parameters = numpy.load(
                tar_file.extractfile(tar_file.getmember('_parameters')))
            s1 = set(parameters.keys())
            s2 = [_unmangle_parameter_name(x)[2] for x in
                  external_parameters.values()]
            if not s1.issuperset(s2):
                raise ValueError('The set of parameters is different'
                                 ' from the one in the archive.')

    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    file_.seek(0)  # To be able to add new things in the tar file.
    with closing(tarfile.TarFile(fileobj=file_, mode='a')) as tar_file:
        save_object = _SaveObject(pickler, object_, external_parameters,
                                  protocol, **kwargs)
        _taradd(save_object, tar_file, name)
369
370
371
def continue_training(path):
    """Continues training using a checkpoint.

    Parameters
    ----------
    path : str
        Path to checkpoint.

    Notes
    -----
    Python picklers can unpickle objects from the global namespace only
    if they are present in the namespace where unpickling happens. Often
    global functions are needed for mapping, filtering and other data
    stream operations. If the main loop uses global objects and this
    function fails with a message like
    ```
    AttributeError: 'module' object has no attribute '...'
    ```
    it means that you need to import these objects.

    Examples
    --------
    This function can be used in two ways: in the script where the main
    loop is defined, or in a different script. For the latter option see
    the Notes section.

    """
    with change_recursion_limit(config.recursion_limit):
        with open(path, "rb") as f:
            # Unpickling deeply nested main loops can exceed the default
            # recursion limit, hence the raised limit above.
            main_loop = load(f)
    main_loop.run()
402
403
404
def dump_and_add_to_dump(object_, file_, parameters=None, to_add=None,
                         use_cpickle=False, protocol=DEFAULT_PROTOCOL,
                         **kwargs):
    r"""Calls both `dump` and `add_to_dump` to serialize several objects.

    This function serializes several objects at the same time, sharing
    parameters through the persistent-id mechanism. Its main advantage
    is that it can be used with `secure_dump`.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    to_add : dict of objects
        A {'name': object} dictionary of additional objects to save in
        the tar archive. Its keys will be used as names in the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    dump(object_, file_, parameters=parameters, use_cpickle=use_cpickle,
         protocol=protocol, **kwargs)
    if to_add is None:
        return
    for name, obj in six.iteritems(to_add):
        add_to_dump(obj, file_, name, parameters=parameters,
                    use_cpickle=use_cpickle, protocol=protocol, **kwargs)
445
446
447
class _PicklerWithWarning(_Pickler):
    """Pickler that adds a warning message.

    Adds a warning message if we try to save an object referenced in the
    main module.

    """
    dispatch = _Pickler.dispatch.copy()

    def save_global(self, obj, name=None, **kwargs):
        module = getattr(obj, '__module__', None)
        if module == '__main__':
            # `name` is an explicit parameter, so it can never appear in
            # `kwargs`; the previous `kwargs.get('name', ...)` lookup
            # always missed and silently ignored a caller-supplied name.
            warnings.warn(
                MAIN_MODULE_WARNING.format(name if name is not None
                                           else obj.__name__)
            )
        _Pickler.save_global(self, obj, name=name, **kwargs)

    dispatch[six.types.FunctionType] = save_global
    if six.PY2:
        # These type names only exist on Python 2.
        dispatch[six.types.ClassType] = save_global
        dispatch[six.types.BuiltinFunctionType] = save_global
        dispatch[six.types.TypeType] = save_global
0 ignored issues
show
Bug introduced by
The Module types does not seem to have a member named TypeType.

This check looks for calls to members that are non-existent. These calls will fail.

The member could have been renamed or removed.

Loading history...
469
470
471
class _SaveObject(object):
    r"""Saves an object using Persistent ID.

    Parameters
    ----------
    pickler : object
        The pickler class to instantiate.
    object_ : object
        The object to pickle.
    external_objects : dict of object
        The external objects to save using persistent id.
    protocol : int, optional
        The pickling protocol to use.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    def __init__(self, pickler, object_, external_objects, protocol, **kwargs):
        self.pickler = pickler
        self.object_ = object_
        self.external_objects = external_objects
        self.protocol = protocol
        self.kwargs = kwargs

    def __call__(self, f):
        # Route persistent-id resolution through the known external
        # objects, then write the pickle to the given file object.
        pickler_instance = self.pickler(f, protocol=self.protocol,
                                        **self.kwargs)
        pickler_instance.persistent_id = _PersistentID(self.external_objects)
        pickler_instance.dump(self.object_)
499
500
501
class _Renamer(object):
    """Returns a new name for the given parameter.

    It maintains a list of names already used to avoid naming
    collisions. It also provides names for variables without
    names.

    Attributes
    ----------
    used_names : set
        The set of names already taken.
    default_name : str
        The name to use if a parameter doesn't have a name. Default:
        'parameter'.

    """
    def __init__(self):
        self.used_names = set()
        self.default_name = 'parameter'

    def __call__(self, parameter):
        brick = get_brick(parameter)
        if brick is not None:
            # Standard Blocks parameter: derive the name from its brick
            # hierarchy, e.g. '|mlp|linear_0.W'.
            path = BRICK_DELIMITER.join(
                [""] + [owner.name for owner in brick.get_unique_path()])
            name = '{}.{}'.format(path, parameter.name)
        elif hasattr(parameter.tag, 'name'):
            # Shared variable annotated through tag.name.
            name = parameter.tag.name
        elif parameter.name is not None:
            # Plain named shared variable.
            name = parameter.name
        else:
            # Anonymous variable: fall back to the default name.
            name = self.default_name
        # Resolve collisions by appending '_2', '_3', ... until unique.
        if name in self.used_names:
            suffix = 2
            while '{}_{}'.format(name, suffix) in self.used_names:
                suffix += 1
            name = '{}_{}'.format(name, suffix)
        self.used_names.add(name)
        return name
548
549
550
def _recreate_numpy_ndarray(_, content):
551
    return numpy.array(content)
552
553
554
def _recreate_cuda_ndarray(_, content):
    """Rebuild a CudaNdarray from saved content (context arg is unused)."""
    return cuda_ndarray.cuda_ndarray.CudaNdarray(content)
556
557
558
def _recreate_pygpu_array(context_name, content):
    """Rebuild a pygpu GpuArray on the Theano context named `context_name`."""
    context = theano.sandbox.gpuarray.get_context(context_name)
    return pygpu.gpuarray.array(content, context=context)
561
562
# Forward map: concrete array class -> tag stored inside the mangled
# persistent-id string (see _mangle_parameter_name).
_ARRAY_TYPE_MAP = {numpy.ndarray: 'numpy_ndarray'}
# Inverse map: tag -> function recreating an array of that kind from
# (context_name, content); consumed via _unmangle_parameter_name.
_INVERSE_ARRAY_TYPE_MAP = {'numpy_ndarray': _recreate_numpy_ndarray}
if cuda_ndarray:
    # Old CUDA backend (theano.sandbox.cuda) is importable here.
    _ARRAY_TYPE_MAP[cuda_ndarray.cuda_ndarray.CudaNdarray] = 'cuda_ndarray'
    _INVERSE_ARRAY_TYPE_MAP['cuda_ndarray'] = _recreate_cuda_ndarray
if pygpu:
    # New gpuarray backend (libgpuarray/pygpu) is importable here.
    _ARRAY_TYPE_MAP[pygpu.gpuarray.GpuArray] = 'gpuarray'
    _INVERSE_ARRAY_TYPE_MAP['gpuarray'] = _recreate_pygpu_array
570
571
572
class _PersistentID(object):
573
    """Returns persistent identifiers for objects saved separately."""
574
    def __init__(self, external_objects):
575
        self.external_objects = external_objects
576
577
    def __call__(self, object_):
578
        return self.external_objects.get(id(object_))
579
580
581
class _PersistentLoad(object):
    """Loads object saved using a PersistentID mechanism."""
    def __init__(self, tar_file):
        self.tar_file = tar_file
        if '_parameters' in tar_file.getnames():
            # Open the saved parameter arrays; entries are looked up by
            # the names produced by _mangle_parameter_name.
            member = tar_file.getmember('_parameters')
            self.parameters = numpy.load(tar_file.extractfile(member))
        self._cache = {}

    def __call__(self, id_):
        # This hook can be invoked several times with the same id_, so
        # cache the recreated arrays to hand back the same object each
        # time instead of building duplicates.
        try:
            return self._cache[id_]
        except KeyError:
            recreate, context_name, name = _unmangle_parameter_name(id_)
            value = recreate(context_name, self.parameters[name])
            self._cache[id_] = value
            return value
599
600
601
def _mangle_parameter_name(parameter, name):
    """Build the persistent-id string for a parameter's value array.

    The format is '#1<type tag>.<context name>.<name>', decoded by
    :func:`_unmangle_parameter_name`.  `name` may itself contain dots
    (unmangling splits on the first two dots only); the context name
    must not.
    """
    array_type = type(parameter.container.storage[0])
    # NOTE(review): this tests the shared variable `parameter` against
    # pygpu.gpuarray.GpuArray, while the raw array lives in
    # parameter.container.storage[0]; as written, context_name stays
    # None unless `parameter` itself is a GpuArray -- confirm the
    # isinstance target is intended.
    context_name = (parameter.context_name
                    if pygpu and
                    isinstance(parameter, pygpu.gpuarray.GpuArray)
                    else None)
    if isinstance(context_name, str) and '.' in context_name:
        # A dot in the context would break the '.'-delimited format.
        raise ValueError("context name must not contain dots")
    return '#1{}.{}.{}'.format(
        _ARRAY_TYPE_MAP[array_type], context_name, name)
611
612
613
def _unmangle_parameter_name(mangled_name):
    """Decode a persistent id into (recreate function, context, name).

    Understands both the current '#1<type>.<context>.<name>' format and
    the legacy '#<type>.<name>' format, which carries no context.
    """
    if not isinstance(mangled_name, str):
        # With protocol 0 on Python 3 the persistent id arrives as a
        # bytes object, for some reason.
        mangled_name = mangled_name.decode('utf8')
    if not mangled_name.startswith('#'):
        raise ValueError("Do not recognize the mangled parameter name")
    if mangled_name.startswith('#1'):
        # The parameter name may contain dots, so split at most twice.
        type_, context_name, name = mangled_name[2:].split('.', 2)
        context_name = None if context_name == 'None' else context_name
    else:
        # Backward compatibility with the old '#<type>.<name>' format.
        type_, name = mangled_name[1:].split('.', 1)
        context_name = None
    return _INVERSE_ARRAY_TYPE_MAP[type_], context_name, name
629
630
631
def _taradd(func, tar_file, name):
632
    """Adds elements dumped by the function `func` to a tar_file.
633
634
    This functions first calls the function `func` and add the file that
635
    `func` dumps to the achive `tar_file`, under the name `name`.
636
637
    Parameters
638
    ----------
639
    func : function
640
        The dumping function.
641
    tar_file : file
642
        The archive that we are filling.
643
    name : str
644
        The name of the dumped file in the archive.
645
646
    """
647
    with tempfile.NamedTemporaryFile('wb', delete=False) as temp_file:
648
        func(temp_file)
649
        temp_file.close()
650
        tar_file.add(temp_file.name, arcname=name)
651
    if os.path.isfile(temp_file.name):
652
        os.remove(temp_file.name)
653
654
655
def _load_parameters_npzfile(file_):
656
    """Loads parameters from a .npz file in a tar archive."""
657
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
658
        return numpy.load(
659
            tar_file.extractfile(tar_file.getmember('_parameters')))
660