Completed
Push — master (510775...568e7a) by Dmitry, created 01:47

blocks.dump_and_add_to_dump()   B

Complexity:  Conditions 3
Size:        Total Lines 41
Duplication: Lines 0, Ratio 0 %

Metric  Value
cc      3
dl      0
loc     41
rs      8.8571
"""Blocks native serialization - tar files with pickles and numpy arrays.

This module provides :func:`load` and :func:`dump` functions that can serve
as drop-in replacements for the respective functions from the standard
:mod:`pickle` module. The main differences between them and the standard
ones are:

    - The dump is physically a tarball, in which the pickle is stored
      as a '_pkl' file.

    - A special file '_parameters' in the tarball can contain the data
      of a selected set of Theano shared variables. This data is
      referenced from `_pkl` using the persistent id mechanism, which means
      that no duplication takes place. The goal here is to save the values
      of the parameters (this is what these shared variables are in most
      cases) in the most robust way possible. The actual format for the
      '_parameters' file is the one used by :func:`numpy.savez`, i.e. a zip
      file of numpy arrays.

    - More objects can be dumped in the archive using the `add_to_dump`
      function. If the object has the same parameters as the one already
      dumped, you can avoid dumping those parameters again thanks to the
      persistent id mechanism.

    - The :func:`dump` function strives to catch situations in which the
      user tries to pickle a function or a class not defined in the global
      namespace and to give a meaningful warning.

In short, this module provides a dumping mechanism that allows for
greater robustness and persistence than standard pickling.

Examples
--------
Consider a standard main loop (without an algorithm and a data stream
for brevity):

>>> from theano import tensor
>>> from blocks.main_loop import MainLoop
>>> from blocks.bricks import MLP, Tanh, Softmax
>>> from blocks.model import Model
>>> mlp = MLP([Tanh(), None], [784, 10, 10])
>>> x = tensor.matrix('features')
>>> y = tensor.lmatrix('targets')
>>> cost = Softmax().categorical_cross_entropy(
...            y.flatten(), mlp.apply(tensor.flatten(x, outdim=2)))
>>> main_loop = MainLoop(None, None, model=Model(cost))

Let's see how the main loop is dumped by :func:`dump`:

>>> from blocks.serialization import dump, load
>>> import tarfile
>>> with open('main_loop.tar', 'wb') as dst:
...     dump(main_loop, dst)
>>> tarball = tarfile.open('main_loop.tar', 'r')
>>> tarball # doctest: +ELLIPSIS
<tarfile.TarFile object at ...>
>>> tarball.getnames()
['_pkl']
>>> tarball.close()

As promised, the dump is a tarball. Since we did not ask for any additional
magic, it just contains the pickled main loop in the '_pkl' file.

Let's do something more interesting:

>>> with open('main_loop.tar', 'wb') as dst:
...     dump(main_loop, dst,
...          parameters=main_loop.model.parameters)
>>> tarball = tarfile.open('main_loop.tar', 'r')
>>> tarball.getnames()
['_parameters', '_pkl']

As requested by specifying the `parameters` argument, the parameters were
saved in a zip file.

>>> import numpy
>>> ps = numpy.load(tarball.extractfile(tarball.getmember('_parameters')))
>>> sorted(ps.keys()) # doctest: +ELLIPSIS
['|mlp|linear_0.W', '|mlp|linear_0.b', '|mlp|linear_1.W', '|mlp|lin...]
>>> ps.close()

The names for the parameters are chosen intelligently to reflect their
position in the brick hierarchy, if they belong to bricks, and by
simply using the `.name` attribute, if they do not.

The loading of the main loop as a whole still works:

>>> with open('main_loop.tar', 'rb') as src:
...     main_loop_loaded = load(src)
>>> main_loop_loaded # doctest: +ELLIPSIS
<blocks.main_loop.MainLoop object at ...>

Additionally, this module provides the convenience routine
:func:`load_parameters`:

>>> with open('main_loop.tar', 'rb') as src:
...     parameters = load_parameters(src)
>>> sorted(parameters.keys()) # doctest: +ELLIPSIS
['/mlp/linear_0.W', '/mlp/linear_0.b', '/mlp/linear_1.W', '/mlp/line...]

Loading parameters saved by :func:`dump` with :func:`load_parameters`
ensures that their hierarchical names are compatible with the
:class:`~blocks.model.Model` and :class:`~blocks.select.Selector` classes.

TODO: Add information about :func:`add_to_dump`.
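
As a rough, untested sketch (assuming the tarball created above, which
already contains a '_parameters' entry, and that `main_loop.log` pickles
cleanly), an extra object can be added under its own name:

>>> from blocks.serialization import add_to_dump  # doctest: +SKIP
>>> with open('main_loop.tar', 'r+b') as dst:  # doctest: +SKIP
...     add_to_dump(main_loop.log, dst, 'log',
...                 parameters=main_loop.model.parameters)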

"""
import numpy
import os
import pickle
import shutil
import six
import tarfile
import tempfile
import warnings

from contextlib import closing
from pickle import HIGHEST_PROTOCOL
try:
    from pickle import DEFAULT_PROTOCOL
    from pickle import _Pickler
except ImportError:
    DEFAULT_PROTOCOL = HIGHEST_PROTOCOL
    from pickle import Pickler as _Pickler
from six.moves import cPickle
try:
    from theano.sandbox.cuda import cuda_ndarray
except ImportError:
    cuda_ndarray = None

from blocks.config import config
from blocks.filter import get_brick
from blocks.utils import change_recursion_limit


BRICK_DELIMITER = '|'
MAIN_MODULE_WARNING = """WARNING: Main loop depends on the function `{}` in \
`__main__` namespace.

Because of limitations to pickling, this means that you will not be able to \
resume your model outside of a namespace containing this function. In other \
words, you can only call `continue_training` from within this script."""


def dump(object_, file_, parameters=None, use_cpickle=False,
         protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object, optionally saving its parameters separately.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    pickle_object : bool
        If False, `object_` will not be serialized, only its parameters.
        This flag can be used when `object_` is not serializable, but one
        still wants to save its parameters. Default: True
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to True will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    with closing(tarfile.TarFile(fileobj=file_, mode='w')) as tar_file:
        external_objects = {}

        def _save_parameters(f):
            renamer = _Renamer()
            named_parameters = {renamer(p): p for p in parameters}
            numpy.savez(f, **{n: p.get_value()
                              for n, p in named_parameters.items()})
            for n, p in named_parameters.items():
                array_ = p.container.storage[0]
                external_objects[id(array_)] = _mangle_parameter_name(
                    type(array_), n)
        if parameters:
            _taradd(_save_parameters, tar_file, '_parameters')
        if object_ is not None:
            save_object = _SaveObject(pickler, object_, external_objects,
                                      protocol, **kwargs)
            _taradd(save_object, tar_file, '_pkl')
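
# A minimal usage sketch of a parameters-only dump: with object_=None the
# archive gets a '_parameters' entry but no '_pkl' entry. Here `parameters`
# is assumed to be a list of Theano shared variables, as in the module
# docstring, and the file name is illustrative:
#
#     with open('parameters_only.tar', 'wb') as dst:
#         dump(None, dst, parameters=parameters)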


def secure_dump(object_, path, dump_function=dump, **kwargs):
    r"""Robust serialization - does not corrupt your files when it fails.

    Parameters
    ----------
    object_ : object
        The object to be saved to the disk.
    path : str
        The destination for saving.
    dump_function : function
        The function that is used to perform the serialization. Must take
        an object and a file object as arguments. By default, :func:`dump`
        is used. An alternative would be :func:`pickle.dump`.
    \*\*kwargs
        Keyword arguments to be passed to `dump_function`.

    """
    try:
        with tempfile.NamedTemporaryFile(delete=False,
                                         dir=config.temp_dir) as temp:
            dump_function(object_, temp, **kwargs)
        shutil.move(temp.name, path)
    except:
        if "temp" in locals():
            os.remove(temp.name)
        raise
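
# A minimal usage sketch: write a checkpoint atomically via a temporary
# file, forwarding the `parameters` keyword to the default `dump_function`
# (:func:`dump`); `main_loop` and the path are illustrative:
#
#     secure_dump(main_loop, 'main_loop.tar',
#                 parameters=main_loop.model.parameters)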


def load(file_, name='_pkl', use_cpickle=False):
    """Loads an object saved using the `dump` function.

    By default, this function loads the object saved by the `dump`
    function. If some objects have been added to the archive using the
    `add_to_dump` function, then you can load them by passing their name
    to the `name` parameter.

    Parameters
    ----------
    file_ : file
        The file that contains the object to load.
    name : str
        Name of the object to load. Default is '_pkl', meaning that the
        original object that was dumped is loaded.
    use_cpickle : bool
        Use cPickle instead of pickle. Default: False.

    Returns
    -------
    The object saved in `file_`.

    """
    file_.seek(0)  # To be able to read several objects in one file
    if use_cpickle:
        unpickler = cPickle.Unpickler
    else:
        unpickler = pickle.Unpickler
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
        p = unpickler(
            tar_file.extractfile(tar_file.getmember(name)))
        if '_parameters' in tar_file.getnames():
            p.persistent_load = _PersistentLoad(tar_file)
        return p.load()
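
# A minimal usage sketch of loading an object that was added to the archive
# with `add_to_dump` under the (illustrative) name 'log':
#
#     with open('main_loop.tar', 'rb') as src:
#         log = load(src, name='log')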


def load_parameters(file_):
    """Loads the parameter values saved by :func:`dump`.

    This function loads the parameters that have been saved separately by
    :func:`dump`, i.e. the ones given to its `parameters` argument.

    Parameters
    ----------
    file_ : file
        The source to load the parameters from.

    Returns
    -------
    A dictionary of (parameter name, numpy array) pairs.

    """
    with closing(_load_parameters_npzfile(file_)) as npz_file:
        return {name.replace(BRICK_DELIMITER, '/'): value
                for name, value in npz_file.items()}


def add_to_dump(object_, file_, name, parameters=None, use_cpickle=False,
                protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object to an existing tar archive.

    This function makes it possible to dump more objects into an existing
    archive. If the object you want to dump possesses the same set of
    shared variables as the object already dumped, you can pass them to
    the `parameters` argument, which avoids serializing them a second
    time. However, this won't work if the shared variables you pass to the
    `parameters` argument are not already in the archive.

    Parameters
    ----------
    object_ : object
        The object to pickle.
    file_ : file
        The destination for saving, opened in read-write mode (`r+`).
    name : str
        The name of the object you are dumping. It will be used as a file
        name in the archive. '_pkl' and '_parameters' are reserved names
        and can't be used.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file. Must be a
        subset of the parameters already in the archive.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to True will disable the
        warning message if you try to pickle objects from the main module!
        Be sure that you don't have the warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    if name in ['_pkl', '_parameters']:
        raise ValueError("_pkl and _parameters are reserved names and can't"
                         " be used as a name for your object.")

    external_parameters = {}
    if parameters is not None:
        renamer = _Renamer()
        named_parameters = {renamer(p): p for p in parameters}
        for n, p in named_parameters.items():
            array_ = p.container.storage[0]
            external_parameters[id(array_)] = _mangle_parameter_name(
                type(array_), n)

        # Check that the parameters are the same as the ones in the archive.
        file_.seek(0)  # To be able to read what is in the tar file already.
        with closing(tarfile.TarFile(fileobj=file_, mode='r')) as tar_file:
            if '_parameters' not in tar_file.getnames():
                raise ValueError("There are no parameters in the archive, so"
                                 " you can't use the `parameters` argument.")
            else:
                parameters = numpy.load(
                    tar_file.extractfile(tar_file.getmember('_parameters')))
                s1 = set(parameters.keys())
                s2 = [_unmangle_parameter_name(x)[1] for x in
                      external_parameters.values()]
                if not s1.issuperset(s2):
                    raise ValueError('The set of parameters is different'
                                     ' from the one in the archive.')

    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    file_.seek(0)  # To be able to add new things in the tar file.
    with closing(tarfile.TarFile(fileobj=file_, mode='a')) as tar_file:
        save_object = _SaveObject(pickler, object_, external_parameters,
                                  protocol, **kwargs)
        _taradd(save_object, tar_file, name)


def continue_training(path):
    """Continues training using a checkpoint.

    Parameters
    ----------
    path : str
        Path to the checkpoint.

    Notes
    -----
    Python picklers can unpickle objects from the global namespace only if
    they are present in the namespace where unpickling happens. Often
    global functions are needed for mapping, filtering and other data
    stream operations. If the main loop uses such global objects and this
    function fails with a message like::

        AttributeError: 'module' object has no attribute '...'

    it means that you need to import these objects first.

    Examples
    --------
    This function can be used in two ways: in the script where the main
    loop is defined, or in a different script. For the latter option, see
    the Notes section.

    """
    with change_recursion_limit(config.recursion_limit):
        with open(path, "rb") as f:
            main_loop = load(f)
    main_loop.run()
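
# A minimal usage sketch, assuming 'main_loop.tar' is a checkpoint written
# by :func:`dump` or :func:`secure_dump`:
#
#     continue_training('main_loop.tar')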


def dump_and_add_to_dump(object_, file_, parameters=None, to_add=None,
                         use_cpickle=False, protocol=DEFAULT_PROTOCOL,
                         **kwargs):
    r"""Calls both `dump` and `add_to_dump` to serialize several objects.

    This function is used to serialize several objects at the same time,
    using persistent IDs. Its main advantage is that it can be used with
    `secure_dump`.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    to_add : dict of objects
        A {'name': object} dictionary of additional objects to save in
        the tar archive. Its keys will be used as names in the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to True will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    dump(object_, file_, parameters=parameters, use_cpickle=use_cpickle,
         protocol=protocol, **kwargs)
    if to_add is not None:
        for name, obj in six.iteritems(to_add):
            add_to_dump(obj, file_, name, parameters=parameters,
                        use_cpickle=use_cpickle, protocol=protocol, **kwargs)
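
# A minimal usage sketch: combined with :func:`secure_dump`, this saves a
# main loop, its parameters, and (illustratively) its log as an extra
# archive member named 'log', all in one atomic write:
#
#     secure_dump(main_loop, 'main_loop.tar',
#                 dump_function=dump_and_add_to_dump,
#                 parameters=main_loop.model.parameters,
#                 to_add={'log': main_loop.log})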


class _PicklerWithWarning(_Pickler):
    """Pickler that adds a warning message.

    Adds a warning message if we try to save an object referenced in the
    main module.

    """
    dispatch = _Pickler.dispatch.copy()

    def save_global(self, obj, name=None, **kwargs):
        module = getattr(obj, '__module__', None)
        if module == '__main__':
            warnings.warn(
                MAIN_MODULE_WARNING.format(kwargs.get('name', obj.__name__))
            )
        _Pickler.save_global(self, obj, name=name, **kwargs)

    dispatch[six.types.FunctionType] = save_global
    if six.PY2:
        dispatch[six.types.ClassType] = save_global
        dispatch[six.types.BuiltinFunctionType] = save_global
        dispatch[six.types.TypeType] = save_global


class _SaveObject(object):
    r"""Saves an object using Persistent ID.

    Parameters
    ----------
    pickler : object
        The pickler to use.
    object_ : object
        The object to pickle.
    external_objects : dict of object
        The external objects to save using persistent id.
    protocol : int, optional
        The pickling protocol to use.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    def __init__(self, pickler, object_, external_objects, protocol, **kwargs):
        self.pickler = pickler
        self.object_ = object_
        self.external_objects = external_objects
        self.protocol = protocol
        self.kwargs = kwargs

    def __call__(self, f):
        p = self.pickler(f, protocol=self.protocol, **self.kwargs)
        p.persistent_id = _PersistentID(self.external_objects)
        p.dump(self.object_)


class _Renamer(object):
    """Returns a new name for the given parameter.

    It maintains a list of names already used to avoid naming
    collisions. It also provides names for variables without
    names.

    Attributes
    ----------
    used_names : set
        The set of names already taken.
    default_name : str
        The name to use if a parameter doesn't have a name. Default:
        'parameter'.

    """
    def __init__(self):
        self.used_names = set()
        self.default_name = 'parameter'

    def __call__(self, parameter):
        # Standard Blocks parameter
        if get_brick(parameter) is not None:
            name = '{}.{}'.format(
                BRICK_DELIMITER.join(
                    [""] + [brick.name for brick in
                            get_brick(parameter).get_unique_path()]),
                parameter.name)
        # Shared variables with tag.name
        elif hasattr(parameter.tag, 'name'):
            name = parameter.tag.name
        # Standard shared variable
        elif parameter.name is not None:
            name = parameter.name
        # Variables without names
        else:
            name = self.default_name
        # Handle naming collisions
        if name in self.used_names:
            i = 2
            new_name = '_'.join([name, str(i)])
            while new_name in self.used_names:
                i += 1
                new_name = '_'.join([name, str(i)])
            name = new_name
        self.used_names.add(name)
        return name


_ARRAY_TYPE_MAP = {numpy.ndarray: 'numpy_ndarray'}
_INVERSE_ARRAY_TYPE_MAP = {'numpy_ndarray': numpy.array}
if cuda_ndarray:
    _ARRAY_TYPE_MAP[cuda_ndarray.cuda_ndarray.CudaNdarray] = 'cuda_ndarray'
    _INVERSE_ARRAY_TYPE_MAP['cuda_ndarray'] = \
        cuda_ndarray.cuda_ndarray.CudaNdarray


class _PersistentID(object):
    """Returns persistent identifiers for objects saved separately."""
    def __init__(self, external_objects):
        self.external_objects = external_objects

    def __call__(self, object_):
        return self.external_objects.get(id(object_))


class _PersistentLoad(object):
    """Loads objects saved using the persistent ID mechanism."""
    def __init__(self, tar_file):
        self.tar_file = tar_file
        if '_parameters' in tar_file.getnames():
            self.parameters = numpy.load(
                tar_file.extractfile(tar_file.getmember('_parameters')))

    def __call__(self, id_):
        components = _unmangle_parameter_name(id_)
        return components[0](self.parameters[components[1]])


def _mangle_parameter_name(type_, name):
    return '#{}.{}'.format(_ARRAY_TYPE_MAP[type_], name)


def _unmangle_parameter_name(mangled_name):
    if mangled_name.startswith('#'):
        type_, name = mangled_name[1:].split('.', 1)
        return _INVERSE_ARRAY_TYPE_MAP[type_], name
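
# For illustration: a numpy parameter named '|mlp|linear_0.W' is mangled to
# the persistent id '#numpy_ndarray.|mlp|linear_0.W', which
# _unmangle_parameter_name maps back to (numpy.array, '|mlp|linear_0.W').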


def _taradd(func, tar_file, name):
    """Adds the element dumped by the function `func` to the archive.

    This function first calls the function `func` and then adds the file
    that `func` dumps to the archive `tar_file`, under the name `name`.

    Parameters
    ----------
    func : function
        The dumping function.
    tar_file : file
        The archive that we are filling.
    name : str
        The name of the dumped file in the archive.

    """
    with tempfile.NamedTemporaryFile('wb', delete=False) as temp_file:
        func(temp_file)
        temp_file.close()
        tar_file.add(temp_file.name, arcname=name)
    if os.path.isfile(temp_file.name):
        os.remove(temp_file.name)


def _load_parameters_npzfile(file_):
    """Loads parameters from a .npz file in a tar archive."""
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
        return numpy.load(
            tar_file.extractfile(tar_file.getmember('_parameters')))