Completed · Push to master (f73e69...91b7c0) by Raphael · created 01:35

GPUDataTransmitter.get_givens()   A

Complexity:   Conditions 2 (cc 2)
Size:         Total Lines 5 (loc 5)
Duplication:  Lines 0, Ratio 0 % (dl 0)
Importance:   Changes 1, Bugs 0, Features 0 (c 1, b 0, f 0)
rs:           9.4285
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import theano
import theano.tensor as T


class GPUDataTransmitter(object):
    """
    Cache multiple batches on the GPU, so that a compiled function only
    needs an integer index, not a host-to-device copy, per batch.
    """

    def __init__(self, network, shapes, dtypes, cache_num=10):
        self.network = network
        self.shapes = shapes
        self.cache_num = cache_num
        self.dtypes = dtypes
        self.all_variables = self.network.input_variables + self.network.target_variables
        if len(self.all_variables) != len(shapes):
            raise ValueError("The number of network variables does not match the number of shapes")
        self.iterator = T.iscalar("i")
        self.gpu_caches = []
        self.cpu_datas = []
        # One shared cache per variable, large enough to hold cache_num batches.
        for shape, dtype in zip(self.shapes, self.dtypes):
            cache_shape = [cache_num] + shape
            cache = theano.shared(np.zeros(cache_shape, dtype=dtype))
            self.gpu_caches.append(cache)

    def get_givens(self):
        """Map each network variable to one batch of its GPU cache, selected by the iterator."""
        givens = {}
        for var, cache in zip(self.all_variables, self.gpu_caches):
            givens[var] = cache[self.iterator]
        return givens

    def get_iterator(self):
        """Return the symbolic batch index used to slice the GPU caches."""
        return self.iterator

    def transmit(self, *data_list):
        """Copy one block of batches from host memory into the GPU caches."""
        for cache, data in zip(self.gpu_caches, data_list):
            cache.set_value(data, borrow=True)

    def wrap(self, data_source):
        """Iterate over the data source in blocks of cache_num batches,
        transmitting each block to the GPU and yielding the in-cache indices."""
        if not self.cpu_datas:
            # Load the entire data source into CPU memory on first use
            datas = []
            for _ in self.shapes:
                datas.append([])
            for data_tuple in data_source:
                for i in range(len(data_tuple)):
                    datas[i].append(data_tuple[i])
            for i in range(len(datas)):
                # Drop the last item of every stream
                datas[i] = datas[i][:-1]
            # Convert to numpy arrays
            for data, dtype in zip(datas, self.dtypes):
                self.cpu_datas.append(np.array(data, dtype=dtype))
        data_len = self.cpu_datas[0].shape[0]
        for i in range(0, data_len, self.cache_num):
            # Skip the final block if it cannot fill the cache completely
            if i + self.cache_num > data_len:
                continue
            transmit_datas = [data[i:i+self.cache_num] for data in self.cpu_datas]
            self.transmit(*transmit_datas)
            for n in range(self.cache_num):
                yield [n]
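
For context, a minimal usage sketch of how this class would typically be wired into a Theano training loop. Everything outside the class itself is assumed for illustration: network.cost, network.parameters, the shapes, dtypes, and data_source are hypothetical and not part of the reviewed file.

import theano
import theano.tensor as T

# Hypothetical setup: `network` exposes input_variables/target_variables as
# required by GPUDataTransmitter, plus a cost expression and a parameter list.
transmitter = GPUDataTransmitter(network,
                                 shapes=[[128, 784], [128]],
                                 dtypes=["float32", "int32"],
                                 cache_num=10)

# Plain SGD updates, just to have something to compile.
updates = [(p, p - 0.01 * T.grad(network.cost, p)) for p in network.parameters]

# The compiled function takes only the cache index; get_givens() substitutes
# one cached batch for each network variable, so no host-to-device copy
# happens per call once a block has been transmitted.
train_fn = theano.function(inputs=[transmitter.get_iterator()],
                           outputs=network.cost,
                           givens=transmitter.get_givens(),
                           updates=updates)

for epoch in range(10):
    for args in transmitter.wrap(data_source):
        loss = train_fn(*args)  # args is [n], the index of a batch inside the GPU cache

Routing data through givens is what keeps get_givens() as small as the report shows (Total Lines 5, Conditions 2): the method only builds a variable-to-cache-slice mapping, while all data movement is concentrated in transmit() and wrap().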