|
1
|
|
|
"""Blocks native serialization - tar files with pickles and numpy arrays. |
|
2
|
|
|
|
|
3
|
|
|
This module provides :func:`load` and :func:`dump` functions that can serve |
|
4
|
|
|
as drop-in replacement for the respective functions from the standard |
|
5
|
|
|
:mod:`pickle` module. The main differences between them and the standard |
|
6
|
|
|
ones are: |
|
7
|
|
|
|
|
8
|
|
|
- The dump is physically a tarball, in which the pickle is stored |
|
9
|
|
|
as '_pkl' file. |
|
10
|
|
|
|
|
11
|
|
|
- A special file '_parameters' in the tarball can contain the data |
|
12
|
|
|
of a selected set of Theano shared variables. This data is |
|
13
|
|
|
referenced from `_pkl` using persistent id mechanism, which means |
|
14
|
|
|
that no duplication takes place. The goal here is to save the values |
|
15
|
|
|
of the parameters (this is what these shared variables are in most |
|
16
|
|
|
cases) in the most robust way possible. The actual format for |
|
17
|
|
|
'_parameters' file is the one used by :func:`numpy.savez`, i.e. a zip |
|
18
|
|
|
file of numpy arrays. |
|
19
|
|
|
|
|
20
|
|
|
- More objects can be dumped in the archive using the `add_to_dump` |
|
21
|
|
|
function. If the object has the same parameters as the one already
|
22
|
|
|
dumped, then you can avoid dumping those parameters thanks to the
|
23
|
|
|
persistent id mechanism. |
|
24
|
|
|
|
|
25
|
|
|
- The :func:`dump` strives to catch situations when the user tries |
|
26
|
|
|
to pickle a function or a class not defined in the global namespace |
|
27
|
|
|
and give a meaningful warning. |
|
28
|
|
|
|
|
29
|
|
|
In brief, this module proposes a dumping mechanism which allows for
|
30
|
|
|
greater robustness and persistence than standard pickling.
|
31
|
|
|
|
|
32
|
|
|
Examples |
|
33
|
|
|
-------- |
|
34
|
|
|
Consider a standard main loop (without an algorithm and a data stream |
|
35
|
|
|
for brevity) |
|
36
|
|
|
|
|
37
|
|
|
>>> from theano import tensor |
|
38
|
|
|
>>> from blocks.main_loop import MainLoop |
|
39
|
|
|
>>> from blocks.bricks import MLP, Tanh, Softmax |
|
40
|
|
|
>>> from blocks.model import Model |
|
41
|
|
|
>>> mlp = MLP([Tanh(), None], [784, 10, 10]) |
|
42
|
|
|
>>> x = tensor.matrix('features') |
|
43
|
|
|
>>> y = tensor.lmatrix('targets') |
|
44
|
|
|
>>> cost = Softmax().categorical_cross_entropy( |
|
45
|
|
|
... y.flatten(), mlp.apply(tensor.flatten(x, outdim=2))) |
|
46
|
|
|
>>> main_loop = MainLoop(None, None, model=Model(cost)) |
|
47
|
|
|
|
|
48
|
|
|
Let's see how the main loop is dumped by :func:`dump` |
|
49
|
|
|
|
|
50
|
|
|
>>> from blocks.serialization import dump, load |
|
51
|
|
|
>>> import tarfile |
|
52
|
|
|
>>> with open('main_loop.tar', 'wb') as dst: |
|
53
|
|
|
... dump(main_loop, dst) |
|
54
|
|
|
>>> tarball = tarfile.open('main_loop.tar', 'r') |
|
55
|
|
|
>>> tarball # doctest: +ELLIPSIS |
|
56
|
|
|
<tarfile.TarFile object at ...> |
|
57
|
|
|
>>> tarball.getnames() |
|
58
|
|
|
['_pkl'] |
|
59
|
|
|
>>> tarball.close() |
|
60
|
|
|
|
|
61
|
|
|
As promised, the dump is a tarball. Since we did not ask for any additional |
|
62
|
|
|
magic, it just contains the pickled main loop in '_pkl' file. |
|
63
|
|
|
|
|
64
|
|
|
Let's do something more interesting: |
|
65
|
|
|
|
|
66
|
|
|
>>> with open('main_loop.tar', 'wb') as dst: |
|
67
|
|
|
... dump(main_loop, dst, |
|
68
|
|
|
... parameters=main_loop.model.parameters) |
|
69
|
|
|
>>> tarball = tarfile.open('main_loop.tar', 'r') |
|
70
|
|
|
>>> tarball.getnames() |
|
71
|
|
|
['_parameters', '_pkl'] |
|
72
|
|
|
|
|
73
|
|
|
As requested by specifying the `_parameters` argument, the parameters were |
|
74
|
|
|
saved in a zip file. |
|
75
|
|
|
|
|
76
|
|
|
>>> import numpy |
|
77
|
|
|
>>> ps = numpy.load(tarball.extractfile(tarball.getmember('_parameters'))) |
|
78
|
|
|
>>> sorted(ps.keys()) # doctest: +ELLIPSIS |
|
79
|
|
|
['|mlp|linear_0.W', '|mlp|linear_0.b', '|mlp|linear_1.W', '|mlp|lin...] |
|
80
|
|
|
>>> ps.close() |
|
81
|
|
|
|
|
82
|
|
|
The names for parameters are chosen intelligently to reflect their
|
83
|
|
|
position in the brick hierarchy, if they belong to bricks, and by |
|
84
|
|
|
simply using the `.name` attribute, if they do not. |
|
85
|
|
|
|
|
86
|
|
|
The loading of the main loop as a whole still works: |
|
87
|
|
|
|
|
88
|
|
|
>>> with open('main_loop.tar', 'rb') as src: |
|
89
|
|
|
... main_loop_loaded = load(src) |
|
90
|
|
|
>>> main_loop_loaded # doctest: +ELLIPSIS |
|
91
|
|
|
<blocks.main_loop.MainLoop object at ...> |
|
92
|
|
|
|
|
93
|
|
|
Additionally, this module provides convenience routine |
|
94
|
|
|
:func:`load_parameters`: |
|
95
|
|
|
|
|
96
|
|
|
>>> with open('main_loop.tar', 'rb') as src: |
|
97
|
|
|
... parameters = load_parameters(src) |
|
98
|
|
|
>>> sorted(parameters.keys()) # doctest: +ELLIPSIS |
|
99
|
|
|
['/mlp/linear_0.W', '/mlp/linear_0.b', '/mlp/linear_1.W', '/mlp/line...] |
|
100
|
|
|
|
|
101
|
|
|
Loading parameters saved by :func:`dump` with :func:`load_parameters` |
|
102
|
|
|
ensures that their hierarchical names are compatible with
|
103
|
|
|
:class:`~blocks.model.Model` and :class:`~blocks.select.Selector` classes. |
|
104
|
|
|
|
|
105
|
|
|
TODO: Add information about :func:`add_to_dump`. |
|
106
|
|
|
|
|
107
|
|
|
""" |
|
108
|
|
|
import numpy |
|
109
|
|
|
import os |
|
110
|
|
|
import pickle |
|
111
|
|
|
import shutil |
|
112
|
|
|
import six |
|
113
|
|
|
import tarfile |
|
114
|
|
|
import tempfile |
|
115
|
|
|
import warnings |
|
116
|
|
|
|
|
117
|
|
|
from contextlib import closing |
|
118
|
|
|
from pickle import HIGHEST_PROTOCOL |
|
119
|
|
|
try: |
|
120
|
|
|
from pickle import DEFAULT_PROTOCOL |
|
121
|
|
|
from pickle import _Pickler |
|
122
|
|
|
except ImportError: |
|
123
|
|
|
DEFAULT_PROTOCOL = HIGHEST_PROTOCOL |
|
124
|
|
|
from pickle import Pickler as _Pickler |
|
125
|
|
|
from six.moves import cPickle |
|
126
|
|
|
try: |
|
127
|
|
|
from theano.sandbox.cuda import cuda_ndarray |
|
128
|
|
|
except ImportError: |
|
129
|
|
|
cuda_ndarray = None |
|
130
|
|
|
|
|
131
|
|
|
from blocks.config import config |
|
132
|
|
|
from blocks.filter import get_brick |
|
133
|
|
|
from blocks.utils import change_recursion_limit |
|
134
|
|
|
|
|
135
|
|
|
|
|
136
|
|
|
# Separator used when building hierarchical parameter names from the brick
# path, e.g. '|mlp|linear_0.W' (see `_Renamer.__call__`).
BRICK_DELIMITER = '|'
# Warning emitted by `_PicklerWithWarning.save_global` when pickling
# functions or classes defined in the `__main__` namespace.
MAIN_MODULE_WARNING = """WARNING: Main loop depends on the function `{}` in \
`__main__` namespace.

Because of limitations to pickling, this means that you will not be able to \
resume your model outside of a namespace containing this function. In other \
words, you can only call `continue_training` from within this script."""
|
143
|
|
|
|
|
144
|
|
|
|
|
145
|
|
|
def dump(object_, file_, parameters=None, use_cpickle=False,
         protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object, optionally saving its parameters separately.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    # Fix: the docstring used to document a `pickle_object` parameter that
    # does not exist in the signature; passing `object_=None` is the
    # supported way to save only the parameters.
    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    with closing(tarfile.TarFile(fileobj=file_, mode='w')) as tar_file:
        external_objects = {}

        def _save_parameters(f):
            # Give every parameter a unique name and store all values as a
            # single .npz archive inside the tarball.
            renamer = _Renamer()
            named_parameters = {renamer(p): p for p in parameters}
            numpy.savez(f, **{n: p.get_value()
                              for n, p in named_parameters.items()})
            # Register the raw storage objects so the pickler references
            # them via persistent id instead of serializing them twice.
            for n, p in named_parameters.items():
                array_ = p.container.storage[0]
                external_objects[id(array_)] = _mangle_parameter_name(
                    type(array_), n)
        if parameters:
            _taradd(_save_parameters, tar_file, '_parameters')
        if object_ is not None:
            save_object = _SaveObject(pickler, object_, external_objects,
                                      protocol, **kwargs)
            _taradd(save_object, tar_file, '_pkl')
|
198
|
|
|
|
|
199
|
|
|
|
|
200
|
|
|
def secure_dump(object_, path, dump_function=dump, **kwargs):
    r"""Robust serialization - does not corrupt your files when failed.

    Serializes into a temporary file first and moves it over `path` only
    after serialization succeeded, so a failure mid-dump never leaves a
    corrupted file at `path`.

    Parameters
    ----------
    object_ : object
        The object to be saved to the disk.
    path : str
        The destination for saving.
    dump_function : function
        The function that is used to perform the serialization. Must take
        an object and file object as arguments. By default, :func:`dump` is
        used. An alternative would be :func:`pickle.dump`.
    \*\*kwargs
        Keyword arguments to be passed to `dump_function`.

    """
    # Fix: replace the bare `except:` plus fragile `"temp" in locals()`
    # check with an explicit sentinel and `except BaseException` (same
    # coverage, including KeyboardInterrupt, but lint-clean and clearer).
    temp = None
    try:
        with tempfile.NamedTemporaryFile(delete=False,
                                         dir=config.temp_dir) as temp:
            dump_function(object_, temp, **kwargs)
        shutil.move(temp.name, path)
    except BaseException:
        # Clean up the partial dump, then re-raise the original error.
        if temp is not None:
            os.remove(temp.name)
        raise
|
226
|
|
|
|
|
227
|
|
|
|
|
228
|
|
|
def load(file_, name='_pkl', use_cpickle=False):
    """Loads an object saved using the `dump` function.

    By default this returns the object that :func:`dump` stored under
    '_pkl'. Objects added later with `add_to_dump` can be retrieved by
    passing their archive name via `name`.

    Parameters
    ----------
    file_ : file
        The file that contains the object to load.
    name : str
        Name of the object to load. Default is `_pkl`, i.e. the original
        object that was dumped.
    use_cpickle : bool
        Use cPickle instead of pickle. Default: False.

    Returns
    -------
    The object saved in file_.

    """
    # Rewind so that several objects can be read from the same file.
    file_.seek(0)
    unpickler_class = cPickle.Unpickler if use_cpickle else pickle.Unpickler
    with tarfile.open(fileobj=file_, mode='r') as tar_file:
        member = tar_file.getmember(name)
        unpickler = unpickler_class(tar_file.extractfile(member))
        # When parameters were stored separately, resolve their
        # persistent ids from the '_parameters' member of the archive.
        if '_parameters' in tar_file.getnames():
            unpickler.persistent_load = _PersistentLoad(tar_file)
        return unpickler.load()
|
262
|
|
|
|
|
263
|
|
|
|
|
264
|
|
|
def load_parameters(file_):
    """Loads the parameter values saved by :func:`dump`.

    This function loads the parameters that :func:`dump` saved separately
    (the ones given to its `parameters` argument), translating the saved
    names into the '/'-delimited form used by
    :class:`~blocks.model.Model` and :class:`~blocks.select.Selector`.

    Parameters
    ----------
    file_ : file
        The source to load the parameters from.

    Returns
    -------
    A dictionary of (parameter name, numpy array) pairs.

    """
    npz_file = _load_parameters_npzfile(file_)
    with closing(npz_file):
        renamed = {}
        for name, value in npz_file.items():
            renamed[name.replace(BRICK_DELIMITER, '/')] = value
        return renamed
|
283
|
|
|
|
|
284
|
|
|
|
|
285
|
|
|
def add_to_dump(object_, file_, name, parameters=None, use_cpickle=False,
                protocol=DEFAULT_PROTOCOL, **kwargs):
    r"""Pickles an object to an existing tar archive.

    This function allows to dump more objects to an existing archive. If
    the object you want to dump possesses the same set of shared variables
    as the object already dumped, you can pass them to the `parameters`
    argument, which will avoid them being serialized a second time.
    However, it won't work if the shared variables you pass to the
    `parameters` argument are not already in the archive.

    Parameters
    ----------
    object_ : object
        The object to pickle.
    file_ : file
        The destination for saving, opened in read-write mode (`r+`).
    name : str
        The name of the object you are dumping. It will be used as a file
        name in the archive. '_pkl' and '_parameters' are reserved names
        and can't be used.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file. Must be a
        subset of the parameters already in the archive.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module!
        Be sure that you don't have the warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    # Fixes: docstring typo '_paramters'; the npz file opened below was
    # never closed (leaked zip handle); the `parameters` argument was
    # shadowed by the loaded npz file.
    if name in ['_pkl', '_parameters']:
        raise ValueError("_pkl and _parameters are reserved names and can't"
                         " be used as name for your object.")

    external_parameters = {}
    if parameters is not None:
        renamer = _Renamer()
        named_parameters = {renamer(p): p for p in parameters}
        for n, p in named_parameters.items():
            array_ = p.container.storage[0]
            external_parameters[id(array_)] = _mangle_parameter_name(
                type(array_), n)

        # Check that the parameters are the same as the ones in the archive.
        file_.seek(0)  # To be able to read what is in the tar file already.
        with closing(tarfile.TarFile(fileobj=file_, mode='r')) as tar_file:
            if '_parameters' not in tar_file.getnames():
                raise ValueError("There is no parameters in the archive, so"
                                 " you can't use the argument parameters.")
            # Close the npz file once we have its key set, to release the
            # underlying zip handle.
            with closing(numpy.load(tar_file.extractfile(
                    tar_file.getmember('_parameters')))) as saved:
                saved_names = set(saved.keys())
            requested_names = [_unmangle_parameter_name(x)[1]
                               for x in external_parameters.values()]
            if not saved_names.issuperset(requested_names):
                raise ValueError('The set of parameters is different'
                                 ' from the one in the archive.')

    if use_cpickle:
        pickler = cPickle.Pickler
    else:
        pickler = _PicklerWithWarning
    file_.seek(0)  # To be able to add new things in the tar file.
    with closing(tarfile.TarFile(fileobj=file_, mode='a')) as tar_file:
        save_object = _SaveObject(pickler, object_, external_parameters,
                                  protocol, **kwargs)
        _taradd(save_object, tar_file, name)
|
361
|
|
|
|
|
362
|
|
|
|
|
363
|
|
|
def continue_training(path):
    """Continues training using checkpoint.

    Parameters
    ----------
    path : str
        Path to checkpoint.

    Notes
    -----
    Python picklers can unpickle objects from the global namespace only if
    they are present in the namespace where unpickling happens. Often
    global functions are needed for mapping, filtering and other data
    stream operations. If the main loop uses global objects and this
    function fails with a message like
    ```
    AttributeError: 'module' object has no attribute '...'
    ```
    it means that you need to import these objects.

    Examples
    --------
    This function can be used in two ways: in the script where the main
    loop is defined, or in a different script. For the latter option see
    the Notes section.

    """
    # Deep main loops can exceed the default recursion limit during
    # unpickling, hence the temporary bump.
    with change_recursion_limit(config.recursion_limit), \
            open(path, "rb") as checkpoint:
        resumed_main_loop = load(checkpoint)
    resumed_main_loop.run()
|
394
|
|
|
|
|
395
|
|
|
|
|
396
|
|
|
def dump_and_add_to_dump(object_, file_, parameters=None, to_add=None,
                         use_cpickle=False, protocol=DEFAULT_PROTOCOL,
                         **kwargs):
    r"""Calls both `dump` and `add_to_dump` to serialize several objects.

    This function is used to serialize several objects at the same time,
    using persistent ID. Its main advantage is that it can be used with
    `secure_dump`.

    Parameters
    ----------
    object_ : object
        The object to pickle. If None, only the parameters passed to the
        `parameters` argument will be saved.
    file_ : file
        The destination for saving.
    parameters : list, optional
        Shared variables whose internal numpy arrays should be saved
        separately in the `_parameters` field of the tar file.
    to_add : dict of objects
        A {'name': object} dictionary of additional objects to save in
        the tar archive. Its keys will be used as names in the tar file.
    use_cpickle : bool
        Use cPickle instead of pickle. Setting it to true will disable the
        warning message if you try to pickle objects from the main module,
        so be sure that there is no warning before turning this flag
        on. Default: False.
    protocol : int, optional
        The pickling protocol to use. Unlike Python's built-in pickle, the
        default is set to `2` instead of 0 for Python 2. The Python 3
        default (level 3) is maintained.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    # Fixes docstring typos only ('serialze', 'dictionnary', missing word);
    # behavior is unchanged.
    # Create the archive (and optionally the '_parameters' member) first,
    # so the extra objects can reference the saved parameters.
    dump(object_, file_, parameters=parameters, use_cpickle=use_cpickle,
         protocol=protocol, **kwargs)
    if to_add is not None:
        for name, obj in six.iteritems(to_add):
            add_to_dump(obj, file_, name, parameters=parameters,
                        use_cpickle=use_cpickle, protocol=protocol, **kwargs)
|
437
|
|
|
|
|
438
|
|
|
|
|
439
|
|
|
class _PicklerWithWarning(_Pickler):
    """Pickler that adds a warning message.

    Adds a warning message if we try to save an object referenced in the
    main module.

    """
    dispatch = _Pickler.dispatch.copy()

    def save_global(self, obj, name=None, **kwargs):
        # Warn when the object lives in `__main__`: such pickles can only
        # be loaded from a namespace that re-defines the object.
        module = getattr(obj, '__module__', None)
        if module == '__main__':
            warnings.warn(
                MAIN_MODULE_WARNING.format(kwargs.get('name', obj.__name__))
            )
        _Pickler.save_global(self, obj, name=name, **kwargs)

    dispatch[six.types.FunctionType] = save_global
    if six.PY2:
        # Fix: `types.ClassType` and `types.TypeType` exist only on
        # Python 2 (lint flagged these as non-existent members), so these
        # registrations must stay under the PY2 guard to avoid an
        # AttributeError at import time on Python 3.
        dispatch[six.types.ClassType] = save_global
        dispatch[six.types.BuiltinFunctionType] = save_global
        dispatch[six.types.TypeType] = save_global
|
|
|
|
|
|
461
|
|
|
|
|
462
|
|
|
|
|
463
|
|
|
class _SaveObject(object):
    r"""Saves an object using Persistent ID.

    A callable suitable for :func:`_taradd`: when invoked with a file
    object it pickles `object_` into it, routing the registered external
    objects through the persistent-id mechanism.

    Parameters
    ----------
    pickler : object
        The pickler to use
    object_ : object
        The object to pickle.
    external_objects : dict of object
        The external objects to save using persistent id.
    protocol : int, optional
        The pickling protocol to use.
    \*\*kwargs
        Keyword arguments to be passed to `pickle.Pickler`.

    """
    def __init__(self, pickler, object_, external_objects, protocol,
                 **kwargs):
        self.pickler = pickler
        self.object_ = object_
        self.external_objects = external_objects
        self.protocol = protocol
        self.kwargs = kwargs

    def __call__(self, f):
        # Build a fresh pickler per call and wire up persistent ids so the
        # external objects are stored as references, not duplicated.
        pickler_instance = self.pickler(f, protocol=self.protocol,
                                        **self.kwargs)
        pickler_instance.persistent_id = _PersistentID(self.external_objects)
        pickler_instance.dump(self.object_)
|
491
|
|
|
|
|
492
|
|
|
|
|
493
|
|
|
class _Renamer(object):
    """Returns a new name for the given parameter.

    It maintains a list of names already used to avoid naming
    collisions. It also provides names for variables without
    names.

    Attributes
    ----------
    used_names : set
        The set of names already taken.
    default_name : str
        The name to use if a parameter doesn't have a name. Default:
        'parameter'.

    """
    def __init__(self):
        self.used_names = set()
        self.default_name = 'parameter'

    def __call__(self, parameter):
        brick = get_brick(parameter)
        if brick is not None:
            # Standard Blocks parameter: build a hierarchy-aware name such
            # as '|mlp|linear_0.W' from the brick path.
            path = BRICK_DELIMITER.join(
                [''] + [owner.name for owner in brick.get_unique_path()])
            name = '{}.{}'.format(path, parameter.name)
        elif hasattr(parameter.tag, 'name'):
            # Shared variable annotated with tag.name.
            name = parameter.tag.name
        elif parameter.name is not None:
            # Plain shared variable.
            name = parameter.name
        else:
            # Variable without any name at all.
            name = self.default_name
        name = self._resolve_collision(name)
        self.used_names.add(name)
        return name

    def _resolve_collision(self, name):
        # Append '_2', '_3', ... until the name is free.
        if name not in self.used_names:
            return name
        suffix = 2
        while '_'.join([name, str(suffix)]) in self.used_names:
            suffix += 1
        return '_'.join([name, str(suffix)])
|
540
|
|
|
|
|
541
|
|
|
|
|
542
|
|
|
# Maps supported array types to the tag embedded in mangled persistent-id
# names (see `_mangle_parameter_name`).
_ARRAY_TYPE_MAP = {numpy.ndarray: 'numpy_ndarray'}
# Inverse mapping: tag -> callable used by `_PersistentLoad` to rebuild the
# array. Deliberately `numpy.array` (a constructor call on the loaded
# data), not the `numpy.ndarray` type itself.
_INVERSE_ARRAY_TYPE_MAP = {'numpy_ndarray': numpy.array}
if cuda_ndarray:
    # Only registered when Theano's CUDA backend could be imported.
    _ARRAY_TYPE_MAP[cuda_ndarray.cuda_ndarray.CudaNdarray] = 'cuda_ndarray'
    _INVERSE_ARRAY_TYPE_MAP['cuda_ndarray'] = \
        cuda_ndarray.cuda_ndarray.CudaNdarray
|
548
|
|
|
|
|
549
|
|
|
|
|
550
|
|
|
class _PersistentID(object): |
|
551
|
|
|
"""Returns persistent identifiers for objects saved separately.""" |
|
552
|
|
|
def __init__(self, external_objects): |
|
553
|
|
|
self.external_objects = external_objects |
|
554
|
|
|
|
|
555
|
|
|
def __call__(self, object_): |
|
556
|
|
|
return self.external_objects.get(id(object_)) |
|
557
|
|
|
|
|
558
|
|
|
|
|
559
|
|
|
class _PersistentLoad(object):
    """Loads object saved using a PersistentID mechanism."""
    def __init__(self, tar_file):
        # Keep the archive handle; the '_parameters' member (a .npz file)
        # holds the arrays referenced by persistent ids in the pickle.
        self.tar_file = tar_file
        if '_parameters' in tar_file.getnames():
            self.parameters = numpy.load(
                tar_file.extractfile(tar_file.getmember('_parameters')))

    def __call__(self, id_):
        # `id_` is a mangled name like '#<type-tag>.<name>'; rebuild the
        # array by applying the matching constructor to the saved data.
        # NOTE(review): assumes `id_` always unmangles successfully —
        # `_unmangle_parameter_name` returns None for other strings, which
        # would raise TypeError here.
        components = _unmangle_parameter_name(id_)
        return components[0](self.parameters[components[1]])
|
570
|
|
|
|
|
571
|
|
|
|
|
572
|
|
|
def _mangle_parameter_name(type_, name):
    """Builds the persistent id '#<type-tag>.<name>' for a parameter array.

    The leading '#' marks the string as a mangled name and the tag records
    the array type so it can be reconstructed on load.

    """
    type_tag = _ARRAY_TYPE_MAP[type_]
    return '#{}.{}'.format(type_tag, name)
|
574
|
|
|
|
|
575
|
|
|
|
|
576
|
|
|
def _unmangle_parameter_name(mangled_name):
    """Inverts `_mangle_parameter_name`.

    Returns a (constructor, name) pair for a mangled persistent id, or
    None when `mangled_name` does not start with the '#' marker.

    """
    if not mangled_name.startswith('#'):
        return None
    type_tag, name = mangled_name[1:].split('.', 1)
    return _INVERSE_ARRAY_TYPE_MAP[type_tag], name
|
580
|
|
|
|
|
581
|
|
|
|
|
582
|
|
|
def _taradd(func, tar_file, name): |
|
583
|
|
|
"""Adds elements dumped by the function `func` to a tar_file. |
|
584
|
|
|
|
|
585
|
|
|
This functions first calls the function `func` and add the file that |
|
586
|
|
|
`func` dumps to the achive `tar_file`, under the name `name`. |
|
587
|
|
|
|
|
588
|
|
|
Parameters |
|
589
|
|
|
---------- |
|
590
|
|
|
func : function |
|
591
|
|
|
The dumping function. |
|
592
|
|
|
tar_file : file |
|
593
|
|
|
The archive that we are filling. |
|
594
|
|
|
name : str |
|
595
|
|
|
The name of the dumped file in the archive. |
|
596
|
|
|
|
|
597
|
|
|
""" |
|
598
|
|
|
with tempfile.NamedTemporaryFile('wb', delete=False) as temp_file: |
|
599
|
|
|
func(temp_file) |
|
600
|
|
|
temp_file.close() |
|
601
|
|
|
tar_file.add(temp_file.name, arcname=name) |
|
602
|
|
|
if os.path.isfile(temp_file.name): |
|
603
|
|
|
os.remove(temp_file.name) |
|
604
|
|
|
|
|
605
|
|
|
|
|
606
|
|
|
def _load_parameters_npzfile(file_): |
|
607
|
|
|
"""Loads parameters from a .npz file in a tar archive.""" |
|
608
|
|
|
with tarfile.open(fileobj=file_, mode='r') as tar_file: |
|
609
|
|
|
return numpy.load( |
|
610
|
|
|
tar_file.extractfile(tar_file.getmember('_parameters'))) |
|
611
|
|
|
|
This check looks for calls to members that are non-existent. These calls will fail.
The member could have been renamed or removed.