Completed
Push — master ( dccd0d...37cade )
by Raphael
01:33
created

var()   A

Complexity

Conditions 1

Size

Total Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
dl 0
loc 5
rs 9.4285
c 0
b 0
f 0
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import numpy as np
5
from functions import FLOATX
6
import os, gzip, cPickle as pickle
7
import logging as loggers
8
logging = loggers.getLogger(__name__)
9
10
def var(theano_tensor, dim=0, test_shape=None, test_dtype=FLOATX):
    """
    Shortcut for create_var.

    All arguments are forwarded unchanged to create_var (previously they were
    replaced by hard-coded defaults, so dim/test_shape/test_dtype were ignored).
    :rtype: whatever create_var returns
    """
    return create_var(theano_tensor, dim=dim, test_shape=test_shape, test_dtype=test_dtype)
15
16
def create_var(theano_tensor, dim=0, test_shape=None, test_dtype=FLOATX):
    """
    Wrap a Theano tensor into the variable for defining neural network.

    :param theano_tensor: the Theano tensor to wrap
    :param dim: last dimension of tensor, 0 indicates that the last dimension is flexible
    :param test_shape: if a list/tuple, a random test value of that shape
        (np.random.rand, i.e. uniform in [0, 1)) cast to ``test_dtype`` is attached;
        an empty list/tuple attaches nothing. Any other non-None object is attached
        as the test value as-is. None (default) attaches nothing.
    :param test_dtype: dtype used to cast the randomly generated test value
    :rtype: NeuralVariable
    """
    from deepy.layers.var import NeuralVariable
    # Named neural_var to avoid shadowing the module-level `var` shortcut.
    neural_var = NeuralVariable(theano_tensor, dim=dim)
    # Compare against None explicitly: `if test_shape:` raises on a
    # multi-element ndarray passed as a ready-made test value.
    if test_shape is None:
        return neural_var
    if isinstance(test_shape, (list, tuple)):
        # An empty shape is skipped: np.random.rand() with no arguments returns
        # a plain float, which has no .astype().
        if test_shape:
            neural_var.set_test_value(np.random.rand(*test_shape).astype(test_dtype))
    else:
        neural_var.set_test_value(test_shape)
    return neural_var
30
31
def fill_parameters(path, networks, exclude_free_params=False, check_parameters=False):
    """
    Load parameters from file to fill all networks sequentially.

    Saved values are matched positionally against the concatenation of every
    network's ``all_parameters``. A count mismatch is logged as a warning; the
    surplus entries on the longer side are ignored.

    :param path: model file, either ``.gz`` (gzip-compressed pickle of a list of
        values) or ``.npz`` (values stored under ``arr_0``, ``arr_1``, ...)
    :param networks: objects exposing ``parameters`` and ``all_parameters`` lists
        whose elements have a ``set_value`` method
    :param exclude_free_params: if True, only targets that appear in some network's
        ``parameters`` are written; the remaining (free) parameters are left as-is
    :param check_parameters: currently unused; kept for interface compatibility
    :raises Exception: if ``path`` does not exist or its extension is unsupported
    """
    if not os.path.exists(path):
        raise Exception("model {} does not exist".format(path))
    # Decide which parameters to load
    normal_params = sum([nn.parameters for nn in networks], [])
    all_params = sum([nn.all_parameters for nn in networks], [])
    # Identity-based lookup set: O(1) membership instead of scanning a list per target.
    normal_param_ids = set(id(p) for p in normal_params)

    def should_load(target):
        # Excluding free params means keeping only the normal (non-free) ones.
        return not exclude_free_params or id(target) in normal_param_ids

    # Load parameters
    if path.endswith(".gz"):
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted model files.
        with gzip.open(path, 'rb') as handle:
            saved_params = pickle.load(handle)
        # Write parameters
        if len(all_params) != len(saved_params):
            logging.warning("parameters in the network: {}, parameters in the dumped model: {}".format(len(all_params), len(saved_params)))
        for target, source in zip(all_params, saved_params):
            if should_load(target):
                target.set_value(source)
    elif path.endswith(".npz"):
        # Context manager closes the underlying zip file once values are copied out.
        with np.load(path) as arrs:
            n_saved = len(arrs.files)
            # Write parameters
            if len(all_params) != n_saved:
                logging.warning("parameters in the network: {}, parameters in the dumped model: {}".format(len(all_params), n_saved))
            # zip stops at the shorter side, mirroring the .gz branch.
            for idx, target in zip(range(n_saved), all_params):
                if should_load(target):
                    target.set_value(arrs['arr_%d' % idx])
    else:
        raise Exception("File format of %s is not supported, use '.gz' or '.npz' or '.uncompressed.gz'" % path)