Pull Request — master (#1069) by Dmitry

"""Blocks native serialization - tar files with pickles and numpy arrays.

This module provides :func:`load` and :func:`dump` functions that can serve
as drop-in replacements for the respective functions from the standard
:mod:`pickle` module. The main differences between them and the standard
ones are:

    - The dump is physically a tarball, in which the pickle is stored
      as the '_pkl' file.

    - A special file '_parameters' in the tarball can contain the data
      of a selected set of Theano shared variables. This data is
      referenced from `_pkl` using the persistent id mechanism, which
      means that no duplication takes place. The goal here is to save the
      values of the parameters (this is what these shared variables are in
      most cases) in the most robust way possible. The actual format for
      the '_parameters' file is the one used by :func:`numpy.savez`, i.e.
      a zip file of numpy arrays.

    - More objects can be dumped in the archive using the `add_to_dump`
      function. If an object has the same parameters as the one already
      dumped, you can avoid dumping those parameters again thanks to the
      persistent id mechanism.

    - The :func:`dump` strives to catch situations when the user tries
      to pickle a function or a class not defined in the global namespace
      and give a meaningful warning.

In short, this module proposes a dumping mechanism which allows for
greater robustness and persistence than standard pickling.

Examples
--------
Consider a standard main loop (without an algorithm and a data stream
for brevity)

>>> from theano import tensor
>>> from blocks.main_loop import MainLoop
>>> from blocks.bricks import MLP, Tanh, Softmax
>>> from blocks.model import Model
>>> mlp = MLP([Tanh(), None], [784, 10, 10])
>>> x = tensor.matrix('features')
>>> y = tensor.lmatrix('targets')
>>> cost = Softmax().categorical_cross_entropy(
...            y.flatten(), mlp.apply(tensor.flatten(x, outdim=2)))
>>> main_loop = MainLoop(None, None, model=Model(cost))

Let's see how the main loop is dumped by :func:`dump`

>>> from blocks.serialization import dump, load
>>> import tarfile
>>> with open('main_loop.tar', 'wb') as dst:
...     dump(main_loop, dst)
>>> tarball = tarfile.open('main_loop.tar', 'r')
>>> tarball # doctest: +ELLIPSIS
<tarfile.TarFile object at ...>
>>> tarball.getnames()
['_pkl']
>>> tarball.close()

As promised, the dump is a tarball. Since we did not ask for any additional
magic, it just contains the pickled main loop in the '_pkl' file.

Let's do something more interesting:

>>> with open('main_loop.tar', 'wb') as dst:
...     dump(main_loop, dst,
...          parameters=main_loop.model.parameters)
>>> tarball = tarfile.open('main_loop.tar', 'r')
>>> tarball.getnames()
['_parameters', '_pkl']

As requested by specifying the `parameters` argument, the parameters were
saved in a zip file.

>>> import numpy
>>> ps = numpy.load(tarball.extractfile(tarball.getmember('_parameters')))
>>> sorted(ps.keys()) # doctest: +ELLIPSIS
['|mlp|linear_0.W', '|mlp|linear_0.b', '|mlp|linear_1.W', '|mlp|lin...]
>>> ps.close()

The names for parameters are chosen intelligently to reflect their
position in the brick hierarchy, if they belong to bricks, and by
simply using the `.name` attribute, if they do not.

The loading of the main loop as a whole still works:

>>> with open('main_loop.tar', 'rb') as src:
...     main_loop_loaded = load(src)
>>> main_loop_loaded # doctest: +ELLIPSIS
<blocks.main_loop.MainLoop object at ...>

Additionally, this module provides a convenience routine
:func:`load_parameters`:

>>> with open('main_loop.tar', 'rb') as src:
...     parameters = load_parameters(src)
>>> sorted(parameters.keys()) # doctest: +ELLIPSIS
['/mlp/linear_0.W', '/mlp/linear_0.b', '/mlp/linear_1.W', '/mlp/line...]

Loading parameters saved by :func:`dump` with :func:`load_parameters`
ensures that their hierarchical names are compatible with
:class:`~blocks.model.Model` and :class:`~blocks.select.Selector` classes.

TODO: Add information about :func:`add_to_dump`.

"""
import numpy
import os
import pickle
import shutil
import six
import tarfile
import tempfile
import warnings
import logging
from contextlib import closing
from pickle import HIGHEST_PROTOCOL
try:
    from pickle import DEFAULT_PROTOCOL
    from pickle import _Pickler
except ImportError:
    DEFAULT_PROTOCOL = HIGHEST_PROTOCOL
    from pickle import Pickler as _Pickler

from six.moves import cPickle
import theano
try:
    from theano.sandbox.cuda import cuda_ndarray
except ImportError:
    cuda_ndarray = None
try:
    import pygpu
except ImportError:
    pygpu = None
from blocks.config import config
from blocks.filter import get_brick
from blocks.utils import change_recursion_limit


logger = logging.getLogger(__name__)

BRICK_DELIMITER = '|'
MAIN_MODULE_WARNING = """WARNING: Main loop depends on the function `{}` in \
`__main__` namespace.

Because of limitations to pickling, this means that you will not be able to \
resume your model outside of a namespace containing this function. In other \
words, you can only call `continue_training` from within this script."""


def dump(object_, file_, parameters=None, use_cpickle=False,
         protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object, optionally saving its parameters separately.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    with closing(tarfile.TarFile(fileobj=file_, mode='w')) as tar_file:
        external_objects = {}

        def _save_parameters(f):
            renamer = _Renamer()
            named_parameters = {renamer(p): p for p in parameters}
            numpy.savez(f, **{n: p.get_value()
                              for n, p in named_parameters.items()})
            for name, p in named_parameters.items():
                array_ = p.container.storage[0]
                # For arrays living on the GPU, record the Theano context
                # name (assumed to be carried by the shared variable's
                # type); use None for plain numpy arrays.
                context_name = (p.type.context_name
                                if pygpu is not None and
                                isinstance(array_, pygpu.gpuarray.GpuArray)
                                else None)
                external_objects[id(array_)] = _mangle_parameter_name(
                    type(array_), context_name, name)
        if parameters:
            _taradd(_save_parameters, tar_file, '_parameters')
        if object_ is not None:
            save_object = _SaveObject(pickler, object_, external_objects,
                                      protocol, **kwargs)
            _taradd(save_object, tar_file, '_pkl')


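# Usage sketch: saving only the parameters by passing ``object_=None``, as
# documented above; assumes `main_loop` built as in the module docstring
# and a hypothetical file name:
#
# >>> with open('parameters_only.tar', 'wb') as dst:
# ...     dump(None, dst, parameters=main_loop.model.parameters)
# >>> tarfile.open('parameters_only.tar', 'r').getnames()
# ['_parameters']

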
def secure_dump(object_, path, dump_function=dump, **kwargs):
    r"""Robust serialization - does not corrupt your files on failure.

    Parameters
    ----------
    object_ : object
        The object to be saved to the disk.
    path : str
        The destination for saving.
    dump_function : function
        The function that is used to perform the serialization. Must take
        an object and a file object as arguments. By default, :func:`dump`
        is used. An alternative would be :func:`pickle.dump`.
    \*\*kwargs
        Keyword arguments to be passed to `dump_function`.

    """
    try:
        logger.debug("Dumping object to a temporary file")
        with tempfile.NamedTemporaryFile(delete=False,
                                         dir=config.temp_dir) as temp:
            dump_function(object_, temp, **kwargs)
        logger.debug("Moving the temporary file")
        shutil.move(temp.name, path)
        logger.debug("Dump finished")
    except:
        # Clean up the temporary file on any failure and re-raise, so the
        # destination file is never left half-written.
        if "temp" in locals():
            os.remove(temp.name)
        raise


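# Usage sketch: checkpointing via `secure_dump`, so that an interrupted
# save cannot corrupt an existing checkpoint; `main_loop` as in the module
# docstring, file name hypothetical:
#
# >>> secure_dump(main_loop, 'checkpoint.tar',
# ...             parameters=main_loop.model.parameters)

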
def load(file_, name='_pkl', use_cpickle=False):
    """Loads an object saved using the `dump` function.

    By default, this function loads the object saved by the `dump`
    function. If some objects have been added to the archive using the
    `add_to_dump` function, then you can load them by passing their name
    to the `name` parameter.

    Parameters
    ----------
    file_ : file
        The file that contains the object to load.
    name : str
        Name of the object to load. Default is `_pkl`, meaning that it is
        the original object which has been dumped that is loaded.
    use_cpickle : bool
        Use cPickle instead of pickle. Default: False.

    Returns
    -------
    The object saved in ``file_``.

    """
    file_.seek(0)  # To be able to read several objects in one file
    if use_cpickle:
        unpickler = cPickle.Unpickler
    else:
        unpickler = pickle.Unpickler
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
        p = unpickler(
            tar_file.extractfile(tar_file.getmember(name)))
        if '_parameters' in tar_file.getnames():
            p.persistent_load = _PersistentLoad(tar_file)
        return p.load()


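# Usage sketch: loading an object that was stored under a custom name with
# `add_to_dump`; the name 'model' is hypothetical:
#
# >>> with open('main_loop.tar', 'rb') as src:
# ...     model = load(src, name='model')

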
def load_parameters(file_):
    """Loads the parameter values saved by :func:`dump`.

    This function loads the parameters that have been saved separately by
    :func:`dump`, i.e. the ones given to its `parameters` argument.

    Parameters
    ----------
    file_ : file
        The source to load the parameters from.

    Returns
    -------
    A dictionary of (parameter name, numpy array) pairs.

    """
    with closing(_load_parameters_npzfile(file_)) as npz_file:
        return {name.replace(BRICK_DELIMITER, '/'): value
                for name, value in npz_file.items()}


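# Usage sketch: the names returned by `load_parameters` use '/' instead of
# the '|' delimiter stored in the archive, so the result can be fed back
# into a model directly; `main_loop` as in the module docstring:
#
# >>> with open('main_loop.tar', 'rb') as src:
# ...     main_loop.model.set_parameter_values(load_parameters(src))

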
def add_to_dump(object_, file_, name, parameters=None, use_cpickle=False,
                protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object to an existing tar archive.

    This function allows to dump more objects to an existing archive. If
    the object you want to dump possesses the same set of shared variables
    as the object already dumped, you can pass them to the `parameters`
    argument, which avoids serializing them a second time. However, it
    won't work if the shared variables you pass to the `parameters`
    argument are not already in the archive.

    Parameters
    ----------
    object_ : object
        The object to pickle.
    file_ : file
        The destination for saving, opened in read-write mode (`r+`).
    name : str
        The name of the object you are dumping. It will be used as a file
        name in the archive. '_pkl' and '_parameters' are reserved names
        and can't be used.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file. Must be a
        subset of the parameters already in the archive.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module!
        Be sure that you don't have the warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    if name in ['_pkl', '_parameters']:
        raise ValueError("_pkl and _parameters are reserved names and can't"
                         " be used as name for your object.")

    external_parameters = {}
    if parameters is not None:
        renamer = _Renamer()
        named_parameters = {renamer(p): p for p in parameters}
        for n, p in named_parameters.items():
            array_ = p.container.storage[0]
            # Compute the context name the same way as in `dump`, so that
            # the mangled names match the ones already in the archive.
            context_name = (p.type.context_name
                            if pygpu is not None and
                            isinstance(array_, pygpu.gpuarray.GpuArray)
                            else None)
            external_parameters[id(array_)] = _mangle_parameter_name(
                type(array_), context_name, n)

        # Check that the parameters are the same as the ones in the archive.
        file_.seek(0)  # To be able to read what is in the tar file already.
        with closing(tarfile.TarFile(fileobj=file_, mode='r')) as tar_file:
            if '_parameters' not in tar_file.getnames():
                raise ValueError("There are no parameters in the archive,"
                                 " so you can't use the argument"
                                 " parameters.")
            else:
                parameters = numpy.load(
                    tar_file.extractfile(tar_file.getmember('_parameters')))
                s1 = set(parameters.keys())
                s2 = [_unmangle_parameter_name(x)[2] for x in
                      external_parameters.values()]
                if not s1.issuperset(s2):
                    raise ValueError('The set of parameters is different'
                                     ' from the one in the archive.')

    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    file_.seek(0)  # To be able to add new things in the tar file.
    with closing(tarfile.TarFile(fileobj=file_, mode='a')) as tar_file:
        save_object = _SaveObject(pickler, object_, external_parameters,
                                  protocol, **kwargs)
        _taradd(save_object, tar_file, name)


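# Usage sketch: dumping the main loop, then adding the model to the same
# archive under the hypothetical name 'model', reusing the parameters that
# are already stored; `main_loop` as in the module docstring:
#
# >>> with open('main_loop.tar', 'wb') as f:
# ...     dump(main_loop, f, parameters=main_loop.model.parameters)
# >>> with open('main_loop.tar', 'r+b') as f:
# ...     add_to_dump(main_loop.model, f, 'model',
# ...                 parameters=main_loop.model.parameters)

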
def continue_training(path):
    """Continues training using a checkpoint.

    Parameters
    ----------
    path : str
        Path to the checkpoint.

    Notes
    -----
    Python picklers can unpickle objects from the global namespace only if
    they are present in the namespace where unpickling happens. Often
    global functions are needed for mapping, filtering and other data
    stream operations. If the main loop uses global objects and this
    function fails with a message like::

        AttributeError: 'module' object has no attribute '...'

    it means that you need to import these objects.

    Examples
    --------
    This function can be used in two ways: in the script where the main
    loop is defined, or in a different script. For the latter option, see
    the Notes section.

    """
    with change_recursion_limit(config.recursion_limit):
        with open(path, "rb") as f:
            main_loop = load(f)
    main_loop.run()


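# Usage sketch: resuming training from a checkpoint file (the path is
# hypothetical):
#
# >>> continue_training('checkpoint.tar')

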
def dump_and_add_to_dump(object_, file_, parameters=None, to_add=None,
                         use_cpickle=False, protocol=DEFAULT_PROTOCOL,
                         **kwargs):
    r"""Calls both `dump` and `add_to_dump` to serialize several objects.

    This function is used to serialize several objects at the same time,
    using persistent ID. Its main advantage is that it can be used with
    `secure_dump`.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    to_add : dict of objects
        A {'name': object} dictionary of additional objects to save in
        the tar archive. Its keys will be used as names in the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    dump(object_, file_, parameters=parameters, use_cpickle=use_cpickle,
         protocol=protocol, **kwargs)
    if to_add is not None:
        for name, obj in six.iteritems(to_add):
            add_to_dump(obj, file_, name, parameters=parameters,
                        use_cpickle=use_cpickle, protocol=protocol, **kwargs)


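# Usage sketch: atomically checkpointing the main loop together with its
# model by routing `dump_and_add_to_dump` through `secure_dump`;
# `main_loop` as in the module docstring, the name 'model' hypothetical:
#
# >>> secure_dump(main_loop, 'checkpoint.tar',
# ...             dump_function=dump_and_add_to_dump,
# ...             parameters=main_loop.model.parameters,
# ...             to_add={'model': main_loop.model})

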
class _PicklerWithWarning(_Pickler):
    """Pickler that adds a warning message.

    Adds a warning message if we try to save an object referenced in the
    main module.

    """
    dispatch = _Pickler.dispatch.copy()

    def save_global(self, obj, name=None, **kwargs):
        module = getattr(obj, '__module__', None)
        if module == '__main__':
            warnings.warn(
                MAIN_MODULE_WARNING.format(kwargs.get('name', obj.__name__))
            )
        _Pickler.save_global(self, obj, name=name, **kwargs)

    dispatch[six.types.FunctionType] = save_global
    if six.PY2:
        dispatch[six.types.ClassType] = save_global
        dispatch[six.types.BuiltinFunctionType] = save_global
        dispatch[six.types.TypeType] = save_global


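# Behaviour sketch: pickling anything defined in `__main__` triggers the
# warning above; `preprocess` is a hypothetical function:
#
# >>> def preprocess(batch):  # defined in __main__
# ...     return batch
# >>> with open('dump.tar', 'wb') as dst:  # doctest: +SKIP
# ...     dump(preprocess, dst)
# WARNING: Main loop depends on the function `preprocess` in `__main__`...

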
class _SaveObject(object):
    r"""Saves an object using Persistent ID.

    Parameters
    ----------
    pickler : object
        The pickler to use.
    object_ : object
        The object to pickle.
    external_objects : dict of object
        The external objects to save using persistent id.
    protocol : int, optional
        The pickling protocol to use.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    def __init__(self, pickler, object_, external_objects, protocol,
                 **kwargs):
        self.pickler = pickler
        self.object_ = object_
        self.external_objects = external_objects
        self.protocol = protocol
        self.kwargs = kwargs

    def __call__(self, f):
        p = self.pickler(f, protocol=self.protocol, **self.kwargs)
        p.persistent_id = _PersistentID(self.external_objects)
        p.dump(self.object_)


class _Renamer(object):
    """Returns a new name for the given parameter.

    It maintains a list of names already used to avoid naming
    collisions. It also provides names for variables without
    names.

    Attributes
    ----------
    used_names : set
        The set of names already taken.
    default_name : str
        The name to use if a parameter doesn't have a name. Default:
        'parameter'.

    """
    def __init__(self):
        self.used_names = set()
        self.default_name = 'parameter'

    def __call__(self, parameter):
        # Standard Blocks parameter
        if get_brick(parameter) is not None:
            name = '{}.{}'.format(
                BRICK_DELIMITER.join(
                    [""] + [brick.name for brick in
                            get_brick(parameter).get_unique_path()]),
                parameter.name)
        # Shared variables with tag.name
        elif hasattr(parameter.tag, 'name'):
            name = parameter.tag.name
        # Standard shared variable
        elif parameter.name is not None:
            name = parameter.name
        # Variables without names
        else:
            name = self.default_name
        # Handle naming collisions
        if name in self.used_names:
            i = 2
            new_name = '_'.join([name, str(i)])
            while new_name in self.used_names:
                i += 1
                new_name = '_'.join([name, str(i)])
            name = new_name
        self.used_names.add(name)
        return name


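# Behaviour sketch: `_Renamer` falls back to `.name` for plain shared
# variables and resolves collisions with numeric suffixes (hypothetical
# variables):
#
# >>> renamer = _Renamer()
# >>> renamer(theano.shared(1, name='W'))  # doctest: +SKIP
# 'W'
# >>> renamer(theano.shared(2, name='W'))  # doctest: +SKIP
# 'W_2'

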
def _recreate_numpy_ndarray(_, content):
    return numpy.array(content)


def _recreate_cuda_ndarray(_, content):
    return cuda_ndarray.cuda_ndarray.CudaNdarray(content)


def _recreate_pygpu_array(context_name, content):
    context = theano.sandbox.gpuarray.get_context(context_name)
    return pygpu.gpuarray.array(content, context=context)


# Maps from array types to the tags used in mangled parameter names, and
# from those tags back to the matching recreation functions.
_ARRAY_TYPE_MAP = {numpy.ndarray: 'numpy_ndarray'}
_INVERSE_ARRAY_TYPE_MAP = {'numpy_ndarray': _recreate_numpy_ndarray}
if cuda_ndarray:
    _ARRAY_TYPE_MAP[cuda_ndarray.cuda_ndarray.CudaNdarray] = 'cuda_ndarray'
    _INVERSE_ARRAY_TYPE_MAP['cuda_ndarray'] = _recreate_cuda_ndarray
if pygpu:
    _ARRAY_TYPE_MAP[pygpu.gpuarray.GpuArray] = 'gpuarray'
    _INVERSE_ARRAY_TYPE_MAP['gpuarray'] = _recreate_pygpu_array


class _PersistentID(object):
    """Returns persistent identifiers for objects saved separately."""
    def __init__(self, external_objects):
        self.external_objects = external_objects

    def __call__(self, object_):
        return self.external_objects.get(id(object_))


class _PersistentLoad(object):
    """Loads objects saved using a persistent id mechanism."""
    def __init__(self, tar_file):
        self.tar_file = tar_file
        if '_parameters' in tar_file.getnames():
            self.parameters = numpy.load(
                tar_file.extractfile(tar_file.getmember('_parameters')))
        self._cache = {}

    def __call__(self, id_):
        # As we empirically found out, this method can be called multiple
        # times with the same id_. That's why we need a cache here to
        # avoid creating the same object more than once.
        if id_ not in self._cache:
            # components: (recreation function, context name, parameter name)
            components = _unmangle_parameter_name(id_)
            self._cache[id_] = components[0](
                components[1], self.parameters[components[2]])
        return self._cache[id_]


def _mangle_parameter_name(type_, context_name, name):
    if isinstance(context_name, str) and '.' in context_name:
        raise ValueError("context name must not contain dots")
    return '#1{}.{}.{}'.format(_ARRAY_TYPE_MAP[type_], context_name, name)


def _unmangle_parameter_name(mangled_name):
    if mangled_name.startswith('#1'):
        type_, context_name, name = mangled_name[2:].split('.', 2)
        if context_name == 'None':
            context_name = None
    elif mangled_name.startswith('#'):
        # Backward compatibility with archives saved before the context
        # name was included in the mangled name.
        type_, name = mangled_name[1:].split('.', 1)
        context_name = None
    else:
        raise ValueError("Do not recognize the mangled parameter name")
    return _INVERSE_ARRAY_TYPE_MAP[type_], context_name, name


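# Behaviour sketch: the round trip between mangled and unmangled names;
# the parameter name matches the module docstring example:
#
# >>> _mangle_parameter_name(numpy.ndarray, None, '|mlp|linear_0.W')
# '#1numpy_ndarray.None.|mlp|linear_0.W'
# >>> _unmangle_parameter_name(  # doctest: +ELLIPSIS
# ...     '#1numpy_ndarray.None.|mlp|linear_0.W')
# (<function _recreate_numpy_ndarray at ...>, None, '|mlp|linear_0.W')

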
def _taradd(func, tar_file, name):
    """Adds elements dumped by the function `func` to a tar_file.

    This function first calls the function `func` and then adds the file
    that `func` dumps to the archive `tar_file`, under the name `name`.

    Parameters
    ----------
    func : function
        The dumping function.
    tar_file : file
        The archive that we are filling.
    name : str
        The name of the dumped file in the archive.

    """
    with tempfile.NamedTemporaryFile('wb', delete=False) as temp_file:
        func(temp_file)
        temp_file.close()
        tar_file.add(temp_file.name, arcname=name)
    if os.path.isfile(temp_file.name):
        os.remove(temp_file.name)


def _load_parameters_npzfile(file_):
    """Loads parameters from a .npz file in a tar archive."""
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
        return numpy.load(
            tar_file.extractfile(tar_file.getmember('_parameters')))