Completed: push to master (f73e69...91b7c0) by Raphael, created 01:35

deepy.utils.fill_parameters() (rating: F)

Complexity

Conditions 15

Size

Total Lines 32

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 15
dl 0
loc 32
rs 2.7451

How to fix: Complexity

Complex units like deepy.utils.fill_parameters() often do a lot of different things. To break such a unit down, we need to identify a cohesive component within it. In a class, a common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes. In a function, look instead for runs of statements that work on the same data, such as one branch per file format.

Once you have determined the parts that belong together, you can apply the Extract Class refactoring (or Extract Method, for a function). If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.
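Because fill_parameters() is a function rather than a class, the function-level counterpart of this advice is Extract Method: move each cohesive step into its own helper. The following is a minimal sketch of that idea, not deepy's code. The helper names load_saved_params() and assign_params() and the Param/Network stand-ins are hypothetical. The sketch drops the unused check_parameters argument and assumes exclude_free_params means "set only the normal parameters". Having the loader return a plain list for both formats lets a single assignment loop serve '.gz' and '.npz' alike, which removes the duplicated size check and loop.

```python
import gzip
import logging
import os
import pickle

import numpy as np

logger = logging.getLogger(__name__)


def load_saved_params(path):
    """Return the saved arrays as a list; one branch per supported format."""
    if path.endswith(".gz"):
        # A gzip-compressed pickle holding a list of arrays
        with gzip.open(path, "rb") as handle:
            return list(pickle.load(handle))
    if path.endswith(".npz"):
        # np.savez stores positional arrays under the keys arr_0, arr_1, ...
        with np.load(path) as arrs:
            return [arrs["arr_%d" % idx] for idx in range(len(arrs.files))]
    raise ValueError("File format of %s is not supported, use '.gz' or '.npz'" % path)


def assign_params(all_params, normal_params, saved_params, exclude_free_params):
    """Copy saved arrays into targets by position, optionally skipping free parameters."""
    if len(all_params) != len(saved_params):
        logger.warning("parameters in the network: %d, parameters in the dumped model: %d",
                       len(all_params), len(saved_params))
    for target, source in zip(all_params, saved_params):
        if not exclude_free_params or target in normal_params:
            target.set_value(source)


def fill_parameters(path, networks, exclude_free_params=False):
    """Refactored entry point: validate, collect targets, load, assign."""
    if not os.path.exists(path):
        raise IOError("model {} does not exist".format(path))
    normal_params = [p for nn in networks for p in nn.parameters]
    all_params = [p for nn in networks for p in nn.all_parameters]
    saved_params = load_saved_params(path)
    assign_params(all_params, normal_params, saved_params, exclude_free_params)


class Param(object):
    """Hypothetical stand-in for a Theano shared variable."""

    def __init__(self):
        self.value = None

    def set_value(self, value):
        self.value = value


class Network(object):
    """Hypothetical stand-in exposing `parameters` and `all_parameters`."""

    def __init__(self, n_normal, n_free):
        self.parameters = [Param() for _ in range(n_normal)]
        self.free_parameters = [Param() for _ in range(n_free)]
        # Assumption: normal parameters come first in all_parameters
        self.all_parameters = self.parameters + self.free_parameters
```

Each helper now has a single decision to make, so the format dispatch and the exclusion rule can be tested separately. For example, writing three arrays with np.savez and calling fill_parameters(path, [Network(2, 1)], exclude_free_params=True) sets the two normal parameters and leaves the free one untouched.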

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import gzip
import os
import cPickle as pickle
import logging as loggers

import numpy as np

from functions import FLOATX

logging = loggers.getLogger(__name__)

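def create_var(theano_tensor, dim=0, test_shape=None, test_dtype=FLOATX):
    """
    Wrap a Theano tensor into a variable for defining a neural network.
    :param theano_tensor: the Theano tensor to wrap
    :param dim: last dimension of the tensor; 0 indicates that the last dimension is flexible
    :param test_shape: a shape (list or tuple) for which a random test value is generated,
        or a concrete test value that is attached as-is
    :param test_dtype: dtype of the generated random test value
    :rtype: NeuralVariable
    """
    from deepy.layers.var import NeuralVariable
    var = NeuralVariable(theano_tensor, dim=dim)
    # Compare against None explicitly: `if test_shape:` raises for a
    # multi-element NumPy array passed as a concrete test value.
    if test_shape is not None:
        if isinstance(test_shape, (list, tuple)):
            # Treat test_shape as a shape and draw a random test value of that shape
            var.set_test_value(np.asarray(np.random.rand(*test_shape), dtype=test_dtype))
        else:
            # Treat test_shape as the test value itself
            var.set_test_value(test_shape)
    return var
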
def fill_parameters(path, networks, exclude_free_params=False, check_parameters=False):
    """
    Load parameters from a file and fill all networks sequentially.
    :param path: a gzip-compressed pickle of a list of arrays ('.gz') or a NumPy archive ('.npz')
    :param networks: networks whose parameters are filled in order
    :param exclude_free_params: if True, set only the normal (non-free) parameters
    :param check_parameters: currently unused
    """
    if not os.path.exists(path):
        raise Exception("model {} does not exist".format(path))
    # Decide which parameters to load
    normal_params = sum([nn.parameters for nn in networks], [])
    all_params = sum([nn.all_parameters for nn in networks], [])
    # Load parameters
    if path.endswith(".gz"):
        # Every path in this branch ends with '.gz', so the file is always gzip-compressed
        with gzip.open(path, 'rb') as handle:
            saved_params = pickle.load(handle)
        # Write parameters
        if len(all_params) != len(saved_params):
            logging.warning("parameters in the network: {}, parameters in the dumped model: {}".format(len(all_params), len(saved_params)))
        for target, source in zip(all_params, saved_params):
            # With exclude_free_params, skip free parameters and set only the normal ones
            if not exclude_free_params or target in normal_params:
                target.set_value(source)
    elif path.endswith(".npz"):
        arrs = np.load(path)
        # Write parameters
        if len(all_params) != len(arrs.files):
            logging.warning("parameters in the network: {}, parameters in the dumped model: {}".format(len(all_params), len(arrs.files)))
        for idx, target in enumerate(all_params[:len(arrs.files)]):
            if not exclude_free_params or target in normal_params:
                # np.savez stores positional arrays as arr_0, arr_1, ...
                source = arrs['arr_%d' % idx]
                target.set_value(source)
    else:
        raise Exception("File format of %s is not supported, use '.gz' or '.npz'" % path)
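The loader accepts two file layouts, and a file written any other way will not line up with the parameters. The sketch below writes both; save_parameters() is a hypothetical helper, not part of deepy. It relies on documented NumPy behaviour: np.savez stores positional arrays under the keys arr_0, arr_1, ..., which is what the '.npz' branch reads back by index, while the '.gz' branch expects a gzip-compressed pickle of a list. Pickle protocol 2 is used because it is the highest protocol Python 2's cPickle can read, on the assumption that the loader runs under Python 2.

```python
import gzip
import pickle

import numpy as np


def save_parameters(path, arrays):
    """Hypothetical writer producing files in the layout fill_parameters() reads.

    `arrays` must be ordered like the concatenated `all_parameters` of the networks.
    """
    if path.endswith(".gz"):
        # A gzip-compressed pickle of the plain list of arrays
        with gzip.open(path, "wb") as handle:
            pickle.dump(list(arrays), handle, protocol=2)
    elif path.endswith(".npz"):
        # Positional arguments are stored under the keys arr_0, arr_1, ...
        np.savez(path, *arrays)
    else:
        raise ValueError("unsupported extension: %s" % path)
```

The array order must match the concatenated all_parameters of the networks passed to fill_parameters(), since both branches assign by position rather than by name.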