"""Extensions for saving and loading the state of a training process."""
import os.path
import logging

from blocks.extensions import SimpleExtension
from blocks.utils import reraise_as
from blocks.serialization import (secure_dump, load, dump_and_add_to_dump,
                                  load_parameters)

logger = logging.getLogger(__name__)

LOADED_FROM = "loaded_from"
SAVED_TO = "saved_to"
15 | |||
class Checkpoint(SimpleExtension):
    """Saves a pickled version of the main loop to the disk.

    The pickled main loop can later be reloaded and training can be
    resumed.

    Writes a `SAVED_TO` record to the log: the serialization destination
    on success, ``None`` on failure. The record value is a tuple of all
    paths that were saved to (there can be more than one if the user
    added a condition with an argument, see :meth:`do` docs).

    Parameters
    ----------
    path : str
        The destination path for pickling.
    parameters : list, optional
        The parameters to save separately. If None, the parameters from
        the model (main_loop.model.parameters) are saved.
    save_separately : list of str, optional
        The list of the main loop's attributes to be saved (copied)
        in a separate file in the tar archive. It may be used for example
        to save the log separately. The name of the attribute will be
        used as name in the tar file.
    save_main_loop : bool
        Choose whether to save the main loop or not. This can be useful
        for example if you are only interested in saving the parameters,
        but not the whole main loop. Defaults to `True`.
    use_cpickle : bool
        See documentation of :func:`~blocks.serialization.dump`.

    Notes
    -----
    Using pickling for saving the whole main loop object comes with
    certain limitations:

    * Theano computation graphs built in the GPU mode
      (`theano.config.device == "gpu"`) can not be used in the usual
      mode (and vice-versa). Therefore using this extension binds you to
      using only one kind of device.

    """
    def __init__(self, path, parameters=None, save_separately=None,
                 save_main_loop=True, use_cpickle=False, **kwargs):
        # By default, also checkpoint once training finishes.
        kwargs.setdefault("after_training", True)
        super(Checkpoint, self).__init__(**kwargs)
        self.path = path
        self.parameters = parameters
        self.save_separately = save_separately
        self.save_main_loop = save_main_loop
        self.use_cpickle = use_cpickle

    def do(self, callback_name, *args):
        """Pickle the main loop object to the disk.

        If `*args` contain an argument from user, it is treated as
        saving path to be used instead of the one given at the
        construction stage.

        """
        logger.info("Checkpointing has started")
        _, from_user = self.parse_args(callback_name, args)
        try:
            destination = self.path
            if from_user:
                # A path supplied via the trigger condition overrides
                # the one configured at construction time.
                destination, = from_user
            extra = (dict((name, getattr(self.main_loop, name))
                          for name in self.save_separately)
                     if self.save_separately else None)
            if self.parameters is None and hasattr(self.main_loop, 'model'):
                # Fall back to the model's parameters and cache them so
                # the same list is reused on later invocations.
                self.parameters = self.main_loop.model.parameters
            # Optionally skip pickling the main loop itself.
            target = self.main_loop if self.save_main_loop else None
            secure_dump(target, destination,
                        dump_function=dump_and_add_to_dump,
                        parameters=self.parameters,
                        to_add=extra,
                        use_cpickle=self.use_cpickle)
        except Exception:
            # Record the failure as ``None`` in the log, then propagate.
            destination = None
            raise
        finally:
            previous = self.main_loop.log.current_row.get(SAVED_TO, ())
            self.main_loop.log.current_row[SAVED_TO] = (
                previous + (destination,))
            logger.info("Checkpointing has finished")
||
106 | |||
107 | |||
class Load(SimpleExtension):
    """Loads a saved checkpoint into the main loop.

    Makes a `LOADED_FROM` record in the log with the dump path.

    Parameters
    ----------
    path : str
        The path to the folder with dump.
    load_iteration_state : bool
        If `True`, load the iteration state. This can be useful when
        your model has very long epochs, and you want to resume when you
        were in the middle of one. Defaults to `False`.
    load_log : bool
        If `True`, load the old log and continue logging from there.
        Convenient because you end up with a single log of the entire
        training history. Defaults to `False`.

    Notes
    -----
    Requires the model to be created entirely using bricks, with a
    unique path/name for each brick, so that the parameters can be
    matched to their values.

    In order to load the iteration state and the log, the saved model
    needs to be unpickled. Note that resuming training this way is still
    not entirely seamless because e.g. extensions will not be reloaded.

    """
    def __init__(self, path, load_iteration_state=False, load_log=False,
                 **kwargs):
        # Loading happens before training begins, by default.
        kwargs.setdefault("before_training", True)
        super(Load, self).__init__(**kwargs)
        self.path = path
        self.load_iteration_state = load_iteration_state
        self.load_log = load_log

    def load_to(self, main_loop):
        """Restore parameters (and optionally state/log) into `main_loop`."""
        with open(self.path, "rb") as archive:
            main_loop.model.set_parameter_values(load_parameters(archive))
            if self.load_log or self.load_iteration_state:
                # Unpickling the whole saved main loop is only needed
                # when the log or the iteration state was requested.
                pickled_loop = load(archive)
                if self.load_log:
                    main_loop.log = pickled_loop.log
                if self.load_iteration_state:
                    main_loop.iteration_state = pickled_loop.iteration_state

    def do(self, *args, **kwargs):
        # Nothing to restore when no dump exists yet; warn and move on.
        if not os.path.exists(self.path):
            logger.warning("No dump found")
            return
        logger.info("loading model from {}".format(self.path))
        try:
            self.load_to(self.main_loop)
            self.main_loop.log.current_row[LOADED_FROM] = self.path
        except Exception:
            # Wrap any loading failure with a descriptive message while
            # preserving the original traceback.
            reraise_as("Failed to load the state")