|
1
|
|
|
#!/usr/bin/env python |
|
2
|
|
|
# -*- coding: utf-8 -*- |
|
3
|
|
|
|
|
4
|
|
|
import numpy as np |
|
5
|
|
|
from functions import FLOATX |
|
6
|
|
|
import os, gzip, cPickle as pickle |
|
7
|
|
|
import logging as loggers |
|
8
|
|
|
logging = loggers.getLogger(__name__) |
|
9
|
|
|
|
|
10
|
|
|
def create_var(theano_tensor, dim=0, test_shape=None, test_dtype=FLOATX):
    """
    Wrap a Theano tensor into the variable for defining neural network.

    :param theano_tensor: tensor passed to ``NeuralVariable``
    :param dim: last dimension of tensor, 0 indicates that the last dimension is flexible
    :param test_shape: optional test value specification. A list or tuple is
        treated as a shape: a random array of that shape (uniform in [0, 1),
        cast to ``test_dtype``) becomes the test value. Any other value
        (e.g. a NumPy array) is used as the test value itself. ``None`` and
        other falsy non-array values leave the test value unset.
    :param test_dtype: dtype of the randomly generated test value
    :rtype: NeuralVariable
    """
    # Imported at call time, as in the original code
    # (NOTE(review): presumably to avoid a circular import -- confirm).
    from deepy.layers.var import NeuralVariable
    var = NeuralVariable(theano_tensor, dim=dim)

    if isinstance(test_shape, np.ndarray):
        # A concrete array is always a test value. It must not go through a
        # truthiness check: bool() of a multi-element array raises ValueError.
        var.set_test_value(test_shape)
    elif isinstance(test_shape, (list, tuple)):
        # A non-empty shape: synthesize random test data with that shape.
        if test_shape:
            var.set_test_value(np.random.rand(*test_shape).astype(test_dtype))
    elif test_shape:
        # Any other truthy object is passed through as the test value.
        var.set_test_value(test_shape)
    return var
|
24
|
|
|
|
|
25
|
|
|
def fill_parameters(path, networks, exclude_free_params=False, check_parameters=False):
    """
    Load parameters from file to fill all network sequentially.

    The saved values are matched by position against the concatenation of
    ``all_parameters`` of every network, in the order the networks are given.
    If the number of saved values differs from the number of network
    parameters, a warning is logged and only the overlapping prefix is loaded.

    :param path: model file; must end with ``.gz`` (gzip-compressed pickle of a
        sequence of values) or ``.npz`` (NumPy archive with ``arr_<index>`` keys)
    :param networks: iterable of objects exposing ``parameters`` and
        ``all_parameters`` lists whose items provide ``set_value``
    :param exclude_free_params: when True, only parameters listed in
        ``parameters`` are filled; the remaining entries of ``all_parameters``
        are left untouched
    :param check_parameters: currently unused; kept for interface compatibility
    :raises Exception: if ``path`` does not exist or has an unsupported extension
    """
    if not os.path.exists(path):
        raise Exception("model {} does not exist".format(path))
    # Decide which parameters to load
    normal_params = sum([nn.parameters for nn in networks], [])
    all_params = sum([nn.all_parameters for nn in networks], [])

    def should_fill(target):
        # Excluding free parameters means keeping only the regular ones.
        # (The previous condition, ``target not in normal_params``, did the
        # opposite and skipped exactly the regular parameters.)
        return not exclude_free_params or target in normal_params

    def warn_on_mismatch(saved_count):
        if len(all_params) != saved_count:
            logging.warning("parameters in the network: {}, parameters in the dumped model: {}".format(len(all_params), saved_count))

    # Load parameters
    if path.endswith(".gz"):
        # SECURITY: unpickling can execute arbitrary code; only load model
        # files that come from a trusted source.
        # The context manager closes the handle even if unpickling fails.
        with gzip.open(path, 'rb') as handle:
            saved_params = pickle.load(handle)
        # Write parameters
        warn_on_mismatch(len(saved_params))
        for target, source in zip(all_params, saved_params):
            if should_fill(target):
                target.set_value(source)
    elif path.endswith(".npz"):
        arrs = np.load(path)
        try:
            # Write parameters
            saved_count = len(arrs.keys())
            warn_on_mismatch(saved_count)
            # zip stops at the shorter side, so only the common prefix is read.
            for idx, target in zip(range(saved_count), all_params):
                if should_fill(target):
                    target.set_value(arrs['arr_%d' % idx])
        finally:
            # The archive keeps its underlying file open until closed.
            arrs.close()
    else:
        # NOTE(review): the message mentions '.uncompressed.gz', but no branch
        # treats that suffix specially (it is read as gzip) -- confirm intent.
        raise Exception("File format of %s is not supported, use '.gz' or '.npz' or '.uncompressed.gz'" % path)