Completed
Push — master (510775...568e7a), by Dmitry, created 01:47

save_separately_filenames()   (rated A)

    Complexity: 2 conditions (cc 2)
    Size: 18 total lines (loc 18)
    Duplication: 0 lines, 0 % (dl 0)
    rs: 9.4285
"""Extensions for saving and loading the state of a training process."""
import os.path
import logging

from blocks.extensions import SimpleExtension, TrainingExtension
from blocks.utils import reraise_as
from blocks.serialization import (secure_dump, load, dump_and_add_to_dump,
                                  load_parameters)

logger = logging.getLogger(__name__)

LOADED_FROM = "loaded_from"
SAVED_TO = "saved_to"


class Checkpoint(SimpleExtension):
    """Saves a pickled version of the main loop to the disk.

    The pickled main loop can later be reloaded and training can be
    resumed.

    Makes a `SAVED_TO` record in the log with the serialization destination
    in the case of success and ``None`` in the case of failure. The value
    of the record is a tuple of paths to which saving was done (there can
    be more than one if the user added a condition with an argument, see
    :meth:`do` docs).

    Parameters
    ----------
    path : str
        The destination path for pickling.
    parameters : list, optional
        The parameters to save separately. If None, the parameters from
        the model (main_loop.model.parameters) are saved.
    save_separately : list of str, optional
        The list of the main loop's attributes to be saved (copied) in a
        separate file in the tar archive. It may be used, for example, to
        save the log separately. The name of the attribute will be used
        as the name in the tar file.
    save_main_loop : bool
        Choose whether to save the main loop or not. This can be useful,
        for example, if you are only interested in saving the parameters,
        but not the whole main loop. Defaults to `True`.
    use_cpickle : bool
        See documentation of :func:`~blocks.serialization.dump`.

    Notes
    -----
    Using pickling for saving the whole main loop object comes with
    certain limitations:

    * Theano computation graphs built in GPU mode
      (`theano.config.device == "gpu"`) cannot be used in the usual mode
      (and vice versa). Therefore using this extension binds you to using
      only one kind of device.

    """
    def __init__(self, path, parameters=None, save_separately=None,
                 save_main_loop=True, use_cpickle=False, **kwargs):
        kwargs.setdefault("after_training", True)
        super(Checkpoint, self).__init__(**kwargs)
        self.path = path
        self.parameters = parameters
        self.save_separately = save_separately
        self.save_main_loop = save_main_loop
        self.use_cpickle = use_cpickle

    def do(self, callback_name, *args):
        """Pickle the main loop object to the disk.

        If `*args` contains an argument from the user, it is treated as
        the saving path to be used instead of the one given at the
        construction stage.

        """
        _, from_user = self.parse_args(callback_name, args)
        try:
            path = self.path
            if from_user:
                # A path passed with the trigger overrides the constructor path.
                path, = from_user
            to_add = None
            if self.save_separately:
                to_add = {attr: getattr(self.main_loop, attr) for attr in
                          self.save_separately}
            if self.parameters is None:
                if hasattr(self.main_loop, 'model'):
                    self.parameters = self.main_loop.model.parameters
            object_ = None
            if self.save_main_loop:
                object_ = self.main_loop
            secure_dump(object_, path,
                        dump_function=dump_and_add_to_dump,
                        parameters=self.parameters,
                        to_add=to_add,
                        use_cpickle=self.use_cpickle)
        except Exception:
            path = None
            raise
        finally:
            # Always record the destination (None if saving failed) in the log.
            already_saved_to = self.main_loop.log.current_row.get(SAVED_TO, ())
            self.main_loop.log.current_row[SAVED_TO] = (already_saved_to +
                                                        (path,))


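# ---------------------------------------------------------------------------
# Usage sketch (editor's addition, not part of the inspected file). A minimal,
# hypothetical way to attach Checkpoint to a training run; `algorithm`,
# `data_stream` and `model` are assumed to be built elsewhere with the usual
# blocks APIs, and `every_n_epochs` is one of the standard SimpleExtension
# trigger keywords:
#
#     from blocks.main_loop import MainLoop
#
#     checkpoint = Checkpoint("training_state.tar",
#                             save_separately=["log"],  # copy the log into the archive
#                             every_n_epochs=1)         # run `do` after every epoch
#     main_loop = MainLoop(algorithm=algorithm, data_stream=data_stream,
#                          model=model, extensions=[checkpoint])
#     main_loop.run()
#
# On every trigger, `do` pickles the main loop with `secure_dump`, adds the
# model parameters and the requested attributes to the tar archive, and
# appends the destination path (or None on failure) to the SAVED_TO record
# of the current log row.
# ---------------------------------------------------------------------------
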
class Load(TrainingExtension):
    """Loads a saved checkpoint into the main loop.

    Makes a `LOADED_FROM` record in the log with the dump path.

    Parameters
    ----------
    path : str
        The path to the dump file.
    load_iteration_state : bool
        If `True`, load the iteration state. This can be useful when your
        model has very long epochs, and you want to resume when you were in
        the middle of one. Defaults to `False`.
    load_log : bool
        If `True`, load the old log and continue logging from there.
        Convenient because you end up with a single log of the entire
        training history. Defaults to `False`.

    Notes
    -----
    Requires the model to be created entirely using bricks, with a unique
    path/name for each brick, so that the parameters can be matched to
    their values.

    In order to load the iteration state and the log, the saved model needs
    to be unpickled. Note that resuming training this way is still not
    entirely seamless because e.g. extensions will not be reloaded.

    """
    def __init__(self, path, load_iteration_state=False, load_log=False,
                 **kwargs):
        super(Load, self).__init__(**kwargs)
        self.path = path
        self.load_iteration_state = load_iteration_state
        self.load_log = load_log

    def load_to(self, main_loop):
        with open(self.path, "rb") as source:
            main_loop.model.set_parameter_values(load_parameters(source))
            if self.load_iteration_state or self.load_log:
                loaded_main_loop = load(source)
                if self.load_log:
                    main_loop.log = loaded_main_loop.log
                if self.load_iteration_state:
                    main_loop.iteration_state = \
                        loaded_main_loop.iteration_state

    def before_training(self):
        if not os.path.exists(self.path):
            logger.warning("No dump found")
            return
        logger.info("loading model from {}".format(self.path))
        try:
            self.load_to(self.main_loop)
            self.main_loop.log.current_row[LOADED_FROM] = self.path
        except Exception:
            reraise_as("Failed to load the state")

Inspection note (Best Practice, flagged at the `except Exception:` in
`Load.before_training`): catching very general exceptions such as `Exception`
is usually not recommended. Generally, you would want to handle very specific
errors in the exception handler; this ensures that you do not hide other types
of errors which should be fixed. So, unless you specifically plan to handle
any error, consider catching a more specific exception.
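A minimal usage sketch for Load (an editor's addition, not part of the
inspected file), assuming a checkpoint previously written to
"training_state.tar" by the Checkpoint extension above and a main loop rebuilt
with the same bricks-based model; `algorithm`, `data_stream` and `model` are
hypothetical objects created elsewhere:

    from blocks.main_loop import MainLoop

    load_ext = Load("training_state.tar",
                    load_iteration_state=True,  # resume mid-epoch if needed
                    load_log=True)              # keep one continuous log
    main_loop = MainLoop(algorithm=algorithm, data_stream=data_stream,
                         model=model, extensions=[load_ext])
    main_loop.run()

Before training starts, `before_training` checks that the dump exists,
restores the model's parameter values with `load_parameters`, and, because
both flags are set, also unpickles the saved main loop to take over its log
and iteration state.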