Completed
Pull Request — master (#1064)
by Dmitry
created 04:46

add_to_dump() — rated F

Complexity

Conditions: 11

Size

Total Lines: 76

Duplication

Lines: 0
Ratio: 0 %

Metric   Value
cc       11
dl       0
loc      76
rs       3.3333

How to fix

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

The most commonly applied refactoring here is Extract Method: move the commented or cohesive part of the body into its own, well-named method.
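As a brief illustration (a hypothetical sketch, not code from this PR), Extract Method replaces an explanatory comment with a named helper:

```python
# Before: the comment describes what the loop does.
def report_before(orders):
    total = 0
    # compute the total price including a flat 10% tax
    for order in orders:
        total += order * 1.10
    return "Total: {:.2f}".format(total)


# After: the commented block becomes a helper whose name
# carries the information the comment used to.
def _total_with_tax(orders, tax_rate=0.10):
    return sum(order * (1 + tax_rate) for order in orders)


def report_after(orders):
    return "Total: {:.2f}".format(_total_with_tax(orders))
```

Both versions produce the same output; only the structure changes.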

Complexity

Complex classes like add_to_dump() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
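As a hypothetical sketch of Extract Class (names are illustrative, not from this PR), fields sharing a common prefix get grouped into their own class:

```python
# Before: the 'tar_' prefix hints at a cohesive component.
class DumperBefore(object):
    def __init__(self, path):
        self.tar_path = path
        self.tar_mode = 'w'
        self.protocol = 2


# After: the prefixed fields are extracted into their own class.
class TarTarget(object):
    def __init__(self, path, mode='w'):
        self.path = path
        self.mode = mode


class DumperAfter(object):
    def __init__(self, path):
        self.tar = TarTarget(path)  # composed, not inlined
        self.protocol = 2
```

The extracted class can then grow behavior of its own without further bloating the original class.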

"""Blocks native serialization - tar files with pickles and numpy arrays.

This module provides :func:`load` and :func:`dump` functions that can serve
as drop-in replacements for the respective functions from the standard
:mod:`pickle` module. The main differences between them and the standard
ones are:

    - The dump is physically a tarball, in which the pickle is stored
      as the '_pkl' file.

    - A special file '_parameters' in the tarball can contain the data
      of a selected set of Theano shared variables. This data is
      referenced from `_pkl` using the persistent id mechanism, which means
      that no duplication takes place. The goal here is to save the values
      of the parameters (this is what these shared variables are in most
      cases) in the most robust way possible. The actual format for the
      '_parameters' file is the one used by :func:`numpy.savez`, i.e. a zip
      file of numpy arrays.

    - More objects can be dumped in the archive using the `add_to_dump`
      function. If an object has the same parameters as the one already
      dumped, you can avoid dumping those parameters again thanks to the
      persistent id mechanism.

    - :func:`dump` strives to catch situations when the user tries
      to pickle a function or a class not defined in the global namespace
      and gives a meaningful warning.

In brief, this module provides a dumping mechanism that allows for
greater robustness and persistence than standard pickling.

Examples
--------
Consider a standard main loop (without an algorithm and a data stream
for brevity)

>>> from theano import tensor
>>> from blocks.main_loop import MainLoop
>>> from blocks.bricks import MLP, Tanh, Softmax
>>> from blocks.model import Model
>>> mlp = MLP([Tanh(), None], [784, 10, 10])
>>> x = tensor.matrix('features')
>>> y = tensor.lmatrix('targets')
>>> cost = Softmax().categorical_cross_entropy(
...            y.flatten(), mlp.apply(tensor.flatten(x, outdim=2)))
>>> main_loop = MainLoop(None, None, model=Model(cost))

Let's see how the main loop is dumped by :func:`dump`

>>> from blocks.serialization import dump, load
>>> import tarfile
>>> with open('main_loop.tar', 'wb') as dst:
...     dump(main_loop, dst)
>>> tarball = tarfile.open('main_loop.tar', 'r')
>>> tarball # doctest: +ELLIPSIS
<tarfile.TarFile object at ...>
>>> tarball.getnames()
['_pkl']
>>> tarball.close()

As promised, the dump is a tarball. Since we did not ask for any additional
magic, it just contains the pickled main loop in the '_pkl' file.

Let's do something more interesting:

>>> with open('main_loop.tar', 'wb') as dst:
...     dump(main_loop, dst,
...          parameters=main_loop.model.parameters)
>>> tarball = tarfile.open('main_loop.tar', 'r')
>>> tarball.getnames()
['_parameters', '_pkl']

As requested by specifying the `parameters` argument, the parameters were
saved in a zip file.

>>> import numpy
>>> ps = numpy.load(tarball.extractfile(tarball.getmember('_parameters')))
>>> sorted(ps.keys()) # doctest: +ELLIPSIS
['|mlp|linear_0.W', '|mlp|linear_0.b', '|mlp|linear_1.W', '|mlp|lin...]
>>> ps.close()

The names for parameters are chosen intelligently to reflect their
position in the brick hierarchy, if they belong to bricks, and by
simply using the `.name` attribute, if they do not.

Loading the main loop as a whole still works:

>>> with open('main_loop.tar', 'rb') as src:
...     main_loop_loaded = load(src)
>>> main_loop_loaded # doctest: +ELLIPSIS
<blocks.main_loop.MainLoop object at ...>

Additionally, this module provides the convenience routine
:func:`load_parameters`:

>>> with open('main_loop.tar', 'rb') as src:
...     parameters = load_parameters(src)
>>> sorted(parameters.keys()) # doctest: +ELLIPSIS
['/mlp/linear_0.W', '/mlp/linear_0.b', '/mlp/linear_1.W', '/mlp/line...]

Loading parameters saved by :func:`dump` with :func:`load_parameters`
ensures that their hierarchical names are compatible with the
:class:`~blocks.model.Model` and :class:`~blocks.select.Selector` classes.

TODO: Add information about :func:`add_to_dump`.

"""
import numpy
import os
import pickle
import shutil
import six
import tarfile
import tempfile
import warnings
import logging

from contextlib import closing
from pickle import HIGHEST_PROTOCOL
try:
    from pickle import DEFAULT_PROTOCOL
    from pickle import _Pickler
except ImportError:
    DEFAULT_PROTOCOL = HIGHEST_PROTOCOL
    from pickle import Pickler as _Pickler
from six.moves import cPickle
try:
    from theano.sandbox.cuda import cuda_ndarray
except ImportError:
    cuda_ndarray = None

from blocks.config import config
from blocks.filter import get_brick
from blocks.utils import change_recursion_limit


logger = logging.getLogger(__name__)

BRICK_DELIMITER = '|'
MAIN_MODULE_WARNING = """WARNING: Main loop depends on the function `{}` in \
`__main__` namespace.

Because of limitations to pickling, this means that you will not be able to \
resume your model outside of a namespace containing this function. In other \
words, you can only call `continue_training` from within this script."""


def dump(object_, file_, parameters=None, use_cpickle=False,
         protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object, optionally saving its parameters separately.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    pickle_object : bool
        If False, `object_` will not be serialized, only its parameters.
        This flag can be used when `object_` is not serializable, but one
        still wants to save its parameters. Default: True
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    with closing(tarfile.TarFile(fileobj=file_, mode='w')) as tar_file:
        external_objects = {}

        def _save_parameters(f):
            renamer = _Renamer()
            named_parameters = {renamer(p): p for p in parameters}
            numpy.savez(f, **{n: p.get_value()
                              for n, p in named_parameters.items()})
            for n, p in named_parameters.items():
                array_ = p.container.storage[0]
                external_objects[id(array_)] = _mangle_parameter_name(
                    type(array_), n)
        if parameters:
            _taradd(_save_parameters, tar_file, '_parameters')
        if object_ is not None:
            save_object = _SaveObject(pickler, object_, external_objects,
                                      protocol, **kwargs)
            _taradd(save_object, tar_file, '_pkl')


def secure_dump(object_, path, dump_function=dump, **kwargs):
    r"""Robust serialization - does not corrupt your files when it fails.

    Parameters
    ----------
    object_ : object
        The object to be saved to the disk.
    path : str
        The destination for saving.
    dump_function : function
        The function that is used to perform the serialization. Must take
        an object and a file object as arguments. By default, :func:`dump`
        is used. An alternative would be :func:`pickle.dump`.
    \*\*kwargs
        Keyword arguments to be passed to `dump_function`.

    """
    try:
        logger.debug("Dumping object to a temporary file")
        with tempfile.NamedTemporaryFile(delete=False,
                                         dir=config.temp_dir) as temp:
            dump_function(object_, temp, **kwargs)
        logger.debug("Moving the temporary file")
        shutil.move(temp.name, path)
        logger.debug("Dump finished")
    except:
        if "temp" in locals():
            os.remove(temp.name)
        raise


def load(file_, name='_pkl', use_cpickle=False):
    """Loads an object saved using the `dump` function.

    By default, this function loads the object saved by the `dump`
    function. If some objects have been added to the archive using the
    `add_to_dump` function, then you can load them by passing their names
    to the `name` parameter.

    Parameters
    ----------
    file_ : file
        The file that contains the object to load.
    name : str
        Name of the object to load. Default is `_pkl`, meaning that the
        originally dumped object is loaded.
    use_cpickle : bool
        Use cPickle instead of pickle. Default: False.

    Returns
    -------
    The object saved in file_.

    """
    file_.seek(0)  # To be able to read several objects in one file
    if use_cpickle:
        unpickler = cPickle.Unpickler
    else:
        unpickler = pickle.Unpickler
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
        p = unpickler(
            tar_file.extractfile(tar_file.getmember(name)))
        if '_parameters' in tar_file.getnames():
            p.persistent_load = _PersistentLoad(tar_file)
        return p.load()


def load_parameters(file_):
    """Loads the parameter values saved by :func:`dump`.

    This function loads the parameters that have been saved separately by
    :func:`dump`, i.e. the ones given to its `parameters` argument.

    Parameters
    ----------
    file_ : file
        The source to load the parameters from.

    Returns
    -------
    A dictionary of (parameter name, numpy array) pairs.

    """
    with closing(_load_parameters_npzfile(file_)) as npz_file:
        return {name.replace(BRICK_DELIMITER, '/'): value
                for name, value in npz_file.items()}


def add_to_dump(object_, file_, name, parameters=None, use_cpickle=False,
                protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object to an existing tar archive.

    This function allows dumping more objects to an existing archive. If
    the object you want to dump possesses the same set of shared variables
    as the object already dumped, you can pass them to the `parameters`
    argument, which avoids serializing them a second time. However, it
    won't work if the shared variables you pass to the `parameters`
    argument are not already in the archive.

    Parameters
    ----------
    object_ : object
        The object to pickle.
    file_ : file
        The destination for saving, opened in read-write mode (`r+`).
    name : str
        The name of the object you are dumping. It will be used as a file
        name in the archive. '_pkl' and '_parameters' are reserved names
        and can't be used.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file. Must be a
        subset of the parameters already in the archive.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module!
        Be sure that you don't have the warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    if name in ['_pkl', '_parameters']:
        raise ValueError("_pkl and _parameters are reserved names and can't"
                         " be used as name for your object.")

    external_parameters = {}
    if parameters is not None:
        renamer = _Renamer()
        named_parameters = {renamer(p): p for p in parameters}
        for n, p in named_parameters.items():
            array_ = p.container.storage[0]
            external_parameters[id(array_)] = _mangle_parameter_name(
                type(array_), n)

        # Check that the parameters are the same as the ones in the archive.
        file_.seek(0)  # To be able to read what is in the tar file already.
        with closing(tarfile.TarFile(fileobj=file_, mode='r')) as tar_file:
            if '_parameters' not in tar_file.getnames():
                raise ValueError("There are no parameters in the archive, so"
                                 " you can't use the argument parameters.")
            else:
                parameters = numpy.load(
                    tar_file.extractfile(tar_file.getmember('_parameters')))
                s1 = set(parameters.keys())
                s2 = [_unmangle_parameter_name(x)[1] for x in
                      external_parameters.values()]
                if not s1.issuperset(s2):
                    raise ValueError('The set of parameters is different'
                                     ' from the one in the archive.')

    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    file_.seek(0)  # To be able to add new things in the tar file.
    with closing(tarfile.TarFile(fileobj=file_, mode='a')) as tar_file:
        save_object = _SaveObject(pickler, object_, external_parameters,
                                  protocol, **kwargs)
        _taradd(save_object, tar_file, name)


def continue_training(path):
    """Continues training using a checkpoint.

    Parameters
    ----------
    path : str
        Path to the checkpoint.

    Notes
    -----
    Python picklers can unpickle objects from the global namespace only if
    they are present in the namespace where unpickling happens. Often
    global functions are needed for mapping, filtering and other data
    stream operations. If the main loop uses global objects and this
    function fails with a message like
    ```
    AttributeError: 'module' object has no attribute '...'
    ```
    it means that you need to import these objects.

    Examples
    --------
    This function can be used in two ways: in the script where the main
    loop is defined, or in a different script. For the latter option, see
    the Notes section.

    """
    with change_recursion_limit(config.recursion_limit):
        with open(path, "rb") as f:
            main_loop = load(f)
    main_loop.run()


def dump_and_add_to_dump(object_, file_, parameters=None, to_add=None,
                         use_cpickle=False, protocol=DEFAULT_PROTOCOL,
                         **kwargs):
    r"""Calls both `dump` and `add_to_dump` to serialize several objects.

    This function is used to serialize several objects at the same time,
    using persistent IDs. Its main advantage is that it can be used with
    `secure_dump`.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    to_add : dict of objects
        A {'name': object} dictionary of additional objects to save in
        the tar archive. Its keys will be used as names in the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    dump(object_, file_, parameters=parameters, use_cpickle=use_cpickle,
         protocol=protocol, **kwargs)
    if to_add is not None:
        for name, obj in six.iteritems(to_add):
            add_to_dump(obj, file_, name, parameters=parameters,
                        use_cpickle=use_cpickle, protocol=protocol, **kwargs)


class _PicklerWithWarning(_Pickler):
    """Pickler that adds a warning message.

    Adds a warning message if we try to save an object referenced in the
    main module.

    """
    dispatch = _Pickler.dispatch.copy()

    def save_global(self, obj, name=None, **kwargs):
        module = getattr(obj, '__module__', None)
        if module == '__main__':
            warnings.warn(
                MAIN_MODULE_WARNING.format(kwargs.get('name', obj.__name__))
            )
        _Pickler.save_global(self, obj, name=name, **kwargs)

    dispatch[six.types.FunctionType] = save_global
    if six.PY2:
        dispatch[six.types.ClassType] = save_global
        dispatch[six.types.BuiltinFunctionType] = save_global
        dispatch[six.types.TypeType] = save_global


class _SaveObject(object):
    r"""Saves an object using a persistent ID.

    Parameters
    ----------
    pickler : object
        The pickler to use.
    object_ : object
        The object to pickle.
    external_objects : dict of object
        The external objects to save using persistent id.
    protocol : int, optional
        The pickling protocol to use.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    def __init__(self, pickler, object_, external_objects, protocol,
                 **kwargs):
        self.pickler = pickler
        self.object_ = object_
        self.external_objects = external_objects
        self.protocol = protocol
        self.kwargs = kwargs

    def __call__(self, f):
        p = self.pickler(f, protocol=self.protocol, **self.kwargs)
        p.persistent_id = _PersistentID(self.external_objects)
        p.dump(self.object_)


class _Renamer(object):
    """Returns a new name for the given parameter.

    It maintains a list of names already used to avoid naming
    collisions. It also provides names for variables without
    names.

    Attributes
    ----------
    used_names : set
        The set of names already taken.
    default_name : str
        The name to use if a parameter doesn't have a name. Default:
        'parameter'.

    """
    def __init__(self):
        self.used_names = set()
        self.default_name = 'parameter'

    def __call__(self, parameter):
        # Standard Blocks parameter
        if get_brick(parameter) is not None:
            name = '{}.{}'.format(
                BRICK_DELIMITER.join(
                    [""] + [brick.name for brick in
                            get_brick(parameter).get_unique_path()]),
                parameter.name)
        # Shared variables with tag.name
        elif hasattr(parameter.tag, 'name'):
            name = parameter.tag.name
        # Standard shared variable
        elif parameter.name is not None:
            name = parameter.name
        # Variables without names
        else:
            name = self.default_name
        # Handle naming collisions
        if name in self.used_names:
            i = 2
            new_name = '_'.join([name, str(i)])
            while new_name in self.used_names:
                i += 1
                new_name = '_'.join([name, str(i)])
            name = new_name
        self.used_names.add(name)
        return name


_ARRAY_TYPE_MAP = {numpy.ndarray: 'numpy_ndarray'}
_INVERSE_ARRAY_TYPE_MAP = {'numpy_ndarray': numpy.array}
if cuda_ndarray:
    _ARRAY_TYPE_MAP[cuda_ndarray.cuda_ndarray.CudaNdarray] = 'cuda_ndarray'
    _INVERSE_ARRAY_TYPE_MAP['cuda_ndarray'] = \
        cuda_ndarray.cuda_ndarray.CudaNdarray


class _PersistentID(object):
    """Returns persistent identifiers for objects saved separately."""
    def __init__(self, external_objects):
        self.external_objects = external_objects

    def __call__(self, object_):
        return self.external_objects.get(id(object_))


class _PersistentLoad(object):
    """Loads objects saved using the persistent ID mechanism."""
    def __init__(self, tar_file):
        self.tar_file = tar_file
        if '_parameters' in tar_file.getnames():
            self.parameters = numpy.load(
                tar_file.extractfile(tar_file.getmember('_parameters')))
        self._cache = {}

    def __call__(self, id_):
        # As we empirically found out, this method can be called multiple
        # times with the same id_. That's why we need a cache here to
        # avoid creating the same object more than once.
        if id_ not in self._cache:
            components = _unmangle_parameter_name(id_)
            self._cache[id_] = components[0](self.parameters[components[1]])
        return self._cache[id_]


def _mangle_parameter_name(type_, name):
    return '#{}.{}'.format(_ARRAY_TYPE_MAP[type_], name)


def _unmangle_parameter_name(mangled_name):
    if mangled_name.startswith('#'):
        type_, name = mangled_name[1:].split('.', 1)
        return _INVERSE_ARRAY_TYPE_MAP[type_], name


def _taradd(func, tar_file, name):
    """Adds elements dumped by the function `func` to a tar_file.

    This function first calls `func` and then adds the file that `func`
    dumps to the archive `tar_file`, under the name `name`.

    Parameters
    ----------
    func : function
        The dumping function.
    tar_file : file
        The archive that we are filling.
    name : str
        The name of the dumped file in the archive.

    """
    with tempfile.NamedTemporaryFile('wb', delete=False) as temp_file:
        func(temp_file)
        temp_file.close()
        tar_file.add(temp_file.name, arcname=name)
    if os.path.isfile(temp_file.name):
        os.remove(temp_file.name)


def _load_parameters_npzfile(file_):
    """Loads parameters from a .npz file in a tar archive."""
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
        return numpy.load(
            tar_file.extractfile(tar_file.getmember('_parameters')))
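The persistent-id mechanism that `_PersistentID` and `_PersistentLoad` rely on is part of the standard `pickle` protocol. A minimal stdlib-only sketch (the names `Saver`, `Loader`, `external`, and `'param_0'` are illustrative, not from this module):

```python
import io
import pickle

# Stand-in for a shared variable's storage that should NOT be
# duplicated inside the pickle stream.
shared = bytearray(b'parameter-data')
external = {id(shared): 'param_0'}  # id -> external name


class Saver(pickle.Pickler):
    def persistent_id(self, obj):
        # Returning a non-None value emits a reference instead of
        # the object's bytes; returning None pickles it normally.
        return external.get(id(obj))


buf = io.BytesIO()
Saver(buf).dump({'weights': shared})


class Loader(pickle.Unpickler):
    def persistent_load(self, pid):
        # Resolve the reference back to the externally stored data.
        return {'param_0': shared}[pid]


buf.seek(0)
restored = Loader(buf).load()
```

The loaded dictionary refers to the externally stored object itself, which is why `dump` can keep parameters in '_parameters' and still have `_pkl` unpickle into objects that point at the same arrays.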
623