Completed
Push — master ( 394368...090fba )
by Raphael
01:33
created

GraphBuilder.create_vars_from_data()   D

Complexity

Conditions 8

Size

Total Lines 36

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 8
c 0
b 0
f 0
dl 0
loc 36
rs 4
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import os
5
import numpy as np
6
import pickle
7
import gzip
8
from inspect import getargspec
9
from env import env
10
import theano
11
import theano.tensor as TT
12
import logging as loggers
13
from tensor_conversion import neural_computation
14
from disconnected_grad import disconnected_grad
15
from deepy.utils import Scanner
16
logging = loggers.getLogger(__name__)
17
18
19
class GraphBuilder(object):
    """
    Tool for creating computational graph in deepy.

    The builder owns one default block and offers thin helpers for creating
    variables, loops, shared values, trainers and computational graphs, plus
    a helper for restoring saved parameters into blocks.
    """

    # Maps the number of dimensions of an array to the name of the matching
    # constructor in `theano.tensor`; an "i" prefix is added for integer data.
    _TENSOR_TYPE_BY_NDIM = {
        0: "scalar",
        1: "vector",
        2: "matrix",
        3: "tensor3",
        4: "tensor4",
        5: "tensor5",
    }

    def __init__(self):
        # NOTE(review): "default_block" is passed positionally, so it ends up
        # in ``*layers`` of `new_block` rather than in a ``name`` keyword.
        # Confirm against `Block` that this is intended.
        self._default_block = self.new_block("default_block")

    def default_block(self):
        """
        Return the default block.
        """
        return self._default_block

    def collect_parameters(self):
        """
        Return the default block, as all parameters will be registered to the default one.
        """
        return self._default_block

    def new_block(self, *layers, **kwargs):
        """
        Create a parameters block.
        :param layers: register some layers in the block
        :param name: specify the name of this block
        :return: the newly created `deepy.layers.block.Block`
        """
        # Imported lazily, as in the rest of this class.
        from deepy.layers.block import Block
        return Block(*layers, **kwargs)

    def var(self, tensor_type, last_dim=0, test_shape=None):
        """
        An alias of deepy.tensor.var.
        """
        from deepy.tensor import var
        return var(tensor_type, last_dim=last_dim, test_shape=test_shape)

    def create_vars_from_data(self, dataset, split="train"):
        """
        Create vars given a dataset and set test values.
        Useful when dataset is already defined.

        Only the first data piece of the chosen split is inspected: one
        variable is created per array in that piece, and the array itself is
        attached as the test value.

        :param dataset: object providing `train_set()`, `valid_set()` and `test_set()`
        :param split: "valid", "test", or anything else for the training split
        :return: list of `NeuralVariable`, in the order of the arrays in the piece
        :raise IndexError: if the selected split contains no data
        """
        from deepy.core.neural_var import NeuralVariable
        if split == "valid":
            data_split = dataset.valid_set()
        elif split == "test":
            data_split = dataset.test_set()
        else:
            data_split = dataset.train_set()
        # Only the first piece is needed, so do not materialize the whole split.
        try:
            first_data_piece = next(iter(data_split))
        except StopIteration:
            raise IndexError("the '{}' split of the dataset is empty".format(split))

        neural_vars = []
        for i, numpy_tensor in enumerate(first_data_piece):
            # Down-cast 64-bit data to the types used for computation.
            if numpy_tensor.dtype == "int64":
                numpy_tensor = numpy_tensor.astype("int32")
            if numpy_tensor.dtype == "float64":
                numpy_tensor = numpy_tensor.astype(env.FLOATX)
            # Unknown ranks fall back to "scalar", as before.
            tensor_type = self._TENSOR_TYPE_BY_NDIM.get(numpy_tensor.ndim, "scalar")
            if numpy_tensor.dtype.kind == "i":
                tensor_type = "i" + tensor_type
            theano_tensor = getattr(TT, tensor_type)("input_{}_{}".format(i + 1, tensor_type))
            # A 0-d value has an empty shape and therefore no last dimension;
            # use 0, which is also the default `last_dim` of `var`.
            last_dim = numpy_tensor.shape[-1] if numpy_tensor.shape else 0
            neural_var = NeuralVariable(theano_tensor, dim=last_dim)
            neural_var.set_test_value(numpy_tensor)
            neural_vars.append(neural_var)
        return neural_vars

    @neural_computation
    def scan(self, func, sequences=None, outputs=None, non_sequences=None, block=None, **kwargs):
        """
        A loop function, the usage is identical with the theano one.

        If both `block` and the updates produced by the scan are non-empty,
        the updates are registered on `block`.

        :type block: deepy.layers.Block
        :return: the results computed by `Scanner`
        """
        results, updates = Scanner(func, sequences, outputs, non_sequences, neural_computation=True, **kwargs).compute()
        if block and updates:
            # `isinstance` also covers dict subclasses (e.g. ordered dicts);
            # unpacking such a mapping directly would pass only its keys.
            if isinstance(updates, dict):
                updates = updates.items()
            block.register_updates(*updates)
        return results

    def loop(self, sequences=None, outputs=None, non_sequences=None, block=None, **kwargs):
        """
        Start a loop.
        Usage:
        ```
        with deepy.graph.loop(sequences={"x": x}, outputs={"o": None}) as vars:
            vars.o = vars.x + 1
        loop_outputs = deepy.graph.loop_outputs()
        result = loop_outputs.o
        ```
        """
        from loop import Loop
        return Loop(sequences, outputs, non_sequences, block, **kwargs)

    def get_trainer(self, model, method='sgd', config=None, annealer=None, validator=None):
        """
        Get a trainer to optimize given model.
        :rtype: deepy.trainers.GeneralNeuralTrainer
        """
        from deepy.trainers import GeneralNeuralTrainer
        return GeneralNeuralTrainer(model, method=method, config=config, annealer=annealer, validator=validator)

    @neural_computation
    def shared(self, value, name=None):
        """
        Create a shared theano scalar value.

        A plain Python `int` is stored as int32 and a plain Python `float`
        as `env.FLOATX`; any other value is handed to theano unchanged.
        """
        # Exact type checks on purpose: subclasses (bool, numpy scalars)
        # must keep passing through untouched.
        if type(value) is int:
            final_value = np.array(value, dtype="int32")
        elif type(value) is float:
            final_value = np.array(value, dtype=env.FLOATX)
        else:
            final_value = value

        return theano.shared(final_value, name=name)

    @neural_computation
    def disconnect(self, x):
        """
        Disconnect a variable from backpropagation.
        """
        return disconnected_grad(x)

    def compile(self, input_dim=0, model=None, input_tensor=None, monitors=None,
                cost=None, output=None, outputs=None, blocks=None, input_vars=None, target_vars=None):
        """
        Create a `ComputationalGraph`, forwarding every argument by keyword.
        """
        from comp_graph import ComputationalGraph
        # Forward explicitly instead of introspecting the signature; this
        # avoids the deprecated `getargspec` and `locals()` lookups.
        return ComputationalGraph(
            input_dim=input_dim,
            model=model,
            input_tensor=input_tensor,
            monitors=monitors,
            cost=cost,
            output=output,
            outputs=outputs,
            blocks=blocks,
            input_vars=input_vars,
            target_vars=target_vars,
        )

    def fill_parameters(self, path, blocks, exclude_free_params=False, check_parameters=False):
        """
        Load parameters from file to fill all blocks sequentially.

        Saved values are matched to the blocks' parameters by position. A
        count mismatch only produces a warning; the surplus on either side
        is ignored.

        :param path: a gzip-compressed pickle ('.gz') or a numpy archive ('.npz')
        :type blocks: list of deepy.layers.Block
        :param exclude_free_params: when true, parameters listed in the
            blocks' `parameters` are skipped (see the review note below)
        :param check_parameters: currently unused
        :raise Exception: if `path` does not exist or has an unsupported extension
        """
        if not os.path.exists(path):
            raise Exception("model {} does not exist".format(path))
        # Decide which parameters to load
        normal_params = sum([nn.parameters for nn in blocks], [])
        all_params = sum([nn.all_parameters for nn in blocks], [])
        # NOTE(review): with `exclude_free_params` set, the filter below skips
        # the targets found in `normal_params` and loads the others. Judging
        # by the flag name the opposite may be intended -- confirm.
        # Load parameters
        if path.endswith(".gz"):
            # SECURITY: unpickling can execute arbitrary code; only load
            # model files from trusted sources.
            with gzip.open(path, 'rb') as handle:
                saved_params = pickle.load(handle)
            # Write parameters
            if len(all_params) != len(saved_params):
                logging.warning(
                    "parameters in the network: {}, parameters in the dumped model: {}".format(len(all_params),
                                                                                               len(saved_params)))
            for target, source in zip(all_params, saved_params):
                if not exclude_free_params or target not in normal_params:
                    target.set_value(source)
        elif path.endswith(".npz"):
            arrs = np.load(path)
            try:
                saved_count = len(arrs.keys())
                # Write parameters
                if len(all_params) != saved_count:
                    logging.warning(
                        "parameters in the network: {}, parameters in the dumped model: {}".format(len(all_params),
                                                                                                   saved_count))
                # Arrays are stored under the positional names arr_0, arr_1, ...
                for idx, target in enumerate(all_params[:saved_count]):
                    if not exclude_free_params or target not in normal_params:
                        target.set_value(arrs['arr_%d' % idx])
            finally:
                # The archive keeps its file open until explicitly closed.
                arrs.close()
        else:
            raise Exception("File format of %s is not supported, use '.gz' or '.npz' or '.uncompressed.gz'" % path)
195
196
197
# Create the module-level builder only if this namespace does not already
# define one, so executing the module again keeps the existing instance.
try:
    graph
except NameError:
    graph = GraphBuilder()
199