1
|
|
|
#!/usr/bin/env python |
2
|
|
|
# -*- coding: utf-8 -*- |
3
|
|
|
|
4
|
|
|
import numpy as np |
5
|
|
|
from functions import FLOATX |
6
|
|
|
import os, gzip, cPickle as pickle |
7
|
|
|
import logging as loggers |
8
|
|
|
logging = loggers.getLogger(__name__) |
9
|
|
|
|
10
|
|
|
def var(theano_tensor, dim=0, test_shape=None, test_dtype=FLOATX):
    """
    A shortcut of create_var.

    All arguments are forwarded to ``create_var`` unchanged (previously they
    were replaced by hard-coded defaults, so ``dim``/``test_shape``/``test_dtype``
    passed here had no effect).

    :rtype: NeuralVariable
    """
    return create_var(theano_tensor, dim=dim, test_shape=test_shape, test_dtype=test_dtype)
15
|
|
|
|
16
|
|
|
def create_var(theano_tensor, dim=0, test_shape=None, test_dtype=FLOATX):
    """
    Wrap a Theano tensor into the variable for defining neural network.

    :param theano_tensor: the Theano tensor to wrap
    :param dim: last dimension of tensor, 0 indicates that the last dimension is flexible
    :param test_shape: optional test value specification. A list or tuple is
        treated as a shape and filled with ``np.random.rand`` values cast to
        ``test_dtype``; an empty list/tuple is ignored. Any other non-None value
        is passed to ``set_test_value`` as the test value itself.
    :param test_dtype: dtype used when a random test value is generated from a shape
    :rtype: NeuralVariable
    """
    from deepy.layers.var import NeuralVariable
    var = NeuralVariable(theano_tensor, dim=dim)
    # Compare against None instead of testing truthiness: an explicit numpy
    # test value with more than one element raises on `if value:`, and a
    # legitimate scalar test value of 0 would be silently dropped.
    if test_shape is None:
        return var
    if isinstance(test_shape, (list, tuple)):
        # An empty shape was ignored by the original truthiness check; keep that.
        if test_shape:
            var.set_test_value(np.random.rand(*test_shape).astype(test_dtype))
    else:
        var.set_test_value(test_shape)
    return var
30
|
|
|
|
31
|
|
|
def fill_parameters(path, networks, exclude_free_params=False, check_parameters=False):
    """
    Load parameters from file to fill all network sequentially.

    The ``all_parameters`` of every network are concatenated in order and
    matched positionally against the saved values. If the counts differ a
    warning is logged and only the overlapping prefix is written.

    :param path: saved model file. ``.gz``: a gzip-compressed pickle holding a
        sequence of parameter values. ``.npz``: a NumPy archive whose arrays are
        named ``arr_0``, ``arr_1``, ... in parameter order. The extension check
        is case-insensitive.
    :param networks: networks to fill; each must expose ``parameters`` and
        ``all_parameters`` (sequences of objects with ``set_value``).
    :param exclude_free_params: if True, only targets listed in a network's
        ``parameters`` are written; the other entries of ``all_parameters``
        are skipped but still consume their saved value, so positions stay aligned.
    :param check_parameters: unused; kept for interface compatibility.
    :raises Exception: if ``path`` does not exist or its extension is not supported.
    """
    if not os.path.exists(path):
        raise Exception("model {} does not exist".format(path))
    # Decide which parameters to load
    normal_params = [p for nn in networks for p in nn.parameters]
    all_params = [p for nn in networks for p in nn.all_parameters]

    def should_load(target):
        # Free parameters are the entries of all_parameters that are not in
        # parameters; when excluding them, write only the normal ones.
        return not exclude_free_params or target in normal_params

    def warn_on_count_mismatch(n_saved):
        if len(all_params) != n_saved:
            logging.warning("parameters in the network: {}, parameters in the dumped model: {}".format(len(all_params), n_saved))

    lower_path = path.lower()
    if lower_path.endswith(".gz"):
        # SECURITY: unpickling can execute arbitrary code -- only load model
        # files from trusted sources.
        with gzip.open(path, 'rb') as handle:
            saved_params = pickle.load(handle)
        # Write parameters
        warn_on_count_mismatch(len(saved_params))
        for target, source in zip(all_params, saved_params):
            if should_load(target):
                target.set_value(source)
    elif lower_path.endswith(".npz"):
        arrs = np.load(path)
        try:
            n_saved = len(arrs.keys())
            # Write parameters
            warn_on_count_mismatch(n_saved)
            for idx, target in enumerate(all_params[:n_saved]):
                if should_load(target):
                    # Only read arrays that are actually written.
                    target.set_value(arrs['arr_%d' % idx])
        finally:
            # Release the file handle held open by the lazy NpzFile.
            arrs.close()
    else:
        raise Exception("File format of %s is not supported, use '.gz' or '.npz' or '.uncompressed.gz'" % path)