Completed — Push to master (3f002e...6b68ea) by David

blocks/extensions/saveload.py (12 issues)
"""Extensions for saving and loading the state of a training process."""
0 ignored issues
show
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.bn -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks.base -> blocks.graph -> blocks.graph.bn -> blocks.filter).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.wrappers -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.bn -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.misc -> blocks.bricks.parallel -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.misc -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.architectures -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.misc -> blocks.bricks.recurrent.base -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
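A common way to break such a cycle is to defer one of the module-level imports to function level, so the importing module can finish loading before the imported one is touched; the file itself is reproduced below. The sketch that follows is purely illustrative: the helper name is hypothetical and ComputationGraph merely stands in for whichever blocks.graph symbol the importing module actually needs.

def get_application_graph(outputs):
    # Hypothetical sketch: the deferred import means blocks.graph is only
    # loaded when this function runs, cutting the module-level cycle through
    # blocks.bricks.base -> blocks.graph.
    from blocks.graph import ComputationGraph
    return ComputationGraph(outputs)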
"""Extensions for saving and loading the state of a training process."""
import os.path
import logging

from blocks.extensions import SimpleExtension, TrainingExtension
from blocks.utils import reraise_as
from blocks.serialization import (secure_dump, load, dump_and_add_to_dump,
                                  load_parameters)

logger = logging.getLogger(__name__)

LOADED_FROM = "loaded_from"
SAVED_TO = "saved_to"


class Checkpoint(SimpleExtension):
    """Saves a pickled version of the main loop to the disk.

    The pickled main loop can be later reloaded and training can be
    resumed.

    Makes a `SAVED_TO` record in the log with the serialization destination
    in the case of success and ``None`` in the case of failure. The
    value of the record is a tuple of paths to which saving was done
    (there can be more than one if the user added a condition
    with an argument, see :meth:`do` docs).

    Parameters
    ----------
    path : str
        The destination path for pickling.
    parameters : list, optional
        The parameters to save separately. If None, the parameters from
        the model (main_loop.model.parameters) are saved.
    save_separately : list of str, optional
        The list of the main loop's attributes to be saved (copied)
        in a separate file in the tar archive. It may be used, for example,
        to save the log separately. The name of the attribute is used
        as the name in the tar file.
    save_main_loop : bool
        Choose whether to save the main loop or not. This can be useful,
        for example, if you are only interested in saving the parameters,
        but not the whole main loop. Defaults to `True`.
    use_cpickle : bool
        See documentation of :func:`~blocks.serialization.dump`.

    Notes
    -----
    Using pickling for saving the whole main loop object comes with
    certain limitations:

    * Theano computation graphs built in GPU mode
      (`theano.config.device == "gpu"`) cannot be used in the usual mode
      (and vice versa). Therefore using this extension binds you to using
      only one kind of device.

    """
    def __init__(self, path, parameters=None, save_separately=None,
                 save_main_loop=True, use_cpickle=False, **kwargs):
        kwargs.setdefault("after_training", True)
        super(Checkpoint, self).__init__(**kwargs)
        self.path = path
        self.parameters = parameters
        self.save_separately = save_separately
        self.save_main_loop = save_main_loop
        self.use_cpickle = use_cpickle

    def do(self, callback_name, *args):
        """Pickle the main loop object to the disk.

        If `*args` contains an argument from the user, it is treated as
        the saving path to be used instead of the one given at the
        construction stage.

        """
        logger.info("Checkpointing has started")
        _, from_user = self.parse_args(callback_name, args)
        try:
            path = self.path
            if from_user:
                path, = from_user
            to_add = None
            if self.save_separately:
                to_add = {attr: getattr(self.main_loop, attr) for attr in
                          self.save_separately}
            if self.parameters is None:
                if hasattr(self.main_loop, 'model'):
                    self.parameters = self.main_loop.model.parameters
            object_ = None
            if self.save_main_loop:
                object_ = self.main_loop
            secure_dump(object_, path,
                        dump_function=dump_and_add_to_dump,
                        parameters=self.parameters,
                        to_add=to_add,
                        use_cpickle=self.use_cpickle)
        except Exception:
            path = None
            raise
        finally:
            already_saved_to = self.main_loop.log.current_row.get(SAVED_TO, ())
            self.main_loop.log.current_row[SAVED_TO] = (already_saved_to +
                                                        (path,))
            logger.info("Checkpointing has finished")

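Before the file's second extension, a minimal usage sketch of Checkpoint as it is typically attached to a main loop. The cost, algorithm and data_stream arguments are assumed to be built elsewhere (e.g. a GradientDescent algorithm and a Fuel stream), and the file name training.tar and the trigger values are arbitrary; the Load class of the file continues right after.

from blocks.extensions import FinishAfter
from blocks.extensions.saveload import Checkpoint
from blocks.main_loop import MainLoop
from blocks.model import Model


def build_main_loop(cost, algorithm, data_stream):
    """Sketch: checkpoint the whole main loop to training.tar every epoch."""
    return MainLoop(
        algorithm=algorithm,
        data_stream=data_stream,
        model=Model(cost),
        extensions=[
            FinishAfter(after_n_epochs=10),
            # every_n_epochs is a SimpleExtension trigger; save_separately
            # copies the log into its own member of the tar archive.
            Checkpoint("training.tar", every_n_epochs=1,
                       save_separately=["log"]),
        ])
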
class Load(TrainingExtension):
    """Loads a saved checkpoint into the main loop.

    Makes a `LOADED_FROM` record in the log with the dump path.

    Parameters
    ----------
    path : str
        The path to the dump file.
    load_iteration_state : bool
        If `True`, load the iteration state. This can be useful when your
        model has very long epochs, and you want to resume when you were in
        the middle of one. Defaults to `False`.
    load_log : bool
        If `True`, load the old log and continue logging from there.
        Convenient because you end up with a single log of the entire
        training history. Defaults to `False`.

    Notes
    -----
    Requires the model to be created entirely using bricks, with a unique
    path/name for each brick, so that the parameters can be matched to
    their values.

    In order to load the iteration state and the log, the saved model needs
    to be unpickled. Note that resuming training this way is still not
    entirely seamless because e.g. extensions will not be reloaded.

    """
    def __init__(self, path, load_iteration_state=False, load_log=False,
                 **kwargs):
        super(Load, self).__init__(**kwargs)
        self.path = path
        self.load_iteration_state = load_iteration_state
        self.load_log = load_log

    def load_to(self, main_loop):
        with open(self.path, "rb") as source:
            main_loop.model.set_parameter_values(load_parameters(source))
            if self.load_iteration_state or self.load_log:
                loaded_main_loop = load(source)
                if self.load_log:
                    main_loop.log = loaded_main_loop.log
                if self.load_iteration_state:
                    main_loop.iteration_state = \
                        loaded_main_loop.iteration_state

    def before_training(self):
        if not os.path.exists(self.path):
            logger.warning("No dump found")
            return
        logger.info("loading model from {}".format(self.path))
        try:
            self.load_to(self.main_loop)
            self.main_loop.log.current_row[LOADED_FROM] = self.path
        except Exception:
            reraise_as("Failed to load the state")
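And a matching sketch of resuming from such a checkpoint with Load. The path mirrors the hypothetical Checkpoint example above; a list like this would be passed as the extensions of a new MainLoop, typically alongside a fresh Checkpoint so the resumed run keeps saving.

from blocks.extensions.saveload import Load

resume_extensions = [
    # Restores the parameters and, since both flags are set, also the old log
    # and the iteration state, so a run interrupted mid-epoch continues where
    # it stopped; on success a LOADED_FROM record is added to the log.
    Load("training.tar", load_iteration_state=True, load_log=True),
]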