Completed
Push to master (f73e69...91b7c0) by Raphael, created 01:35

GPUDataTransmitter.wrap()   (rating: D)

Complexity
    Conditions: 11

Size
    Total Lines: 22

Duplication
    Duplicated Lines: 0
    Ratio: 0 %

Importance
    Changes: 1
    Bugs: 0
    Features: 0

Metric   Value
cc       11        (cyclomatic complexity / conditions)
dl       0         (duplicated lines)
loc      22        (total lines)
rs       4.0714
c        1         (changes)
b        0         (bugs)
f        0         (features)
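
The cc value of 11 is consistent with counting wrap()'s branching constructs: one if in the CPU-staging block, one guard if in the batching loop, seven for loops, and one list comprehension give ten decision points, plus the single base path. As a hedged, tool-agnostic way to reproduce numbers of this kind locally (radon is not necessarily the tool that produced this report, and the file path below is a placeholder):

# A minimal sketch, assuming the radon package is installed (pip install radon).
from radon.complexity import cc_visit
from radon.raw import analyze

with open("gpu_data_transmitter.py") as f:
    source = f.read()

# Cyclomatic complexity per function/method in the file.
for block in cc_visit(source):
    print("%s: cc=%d" % (block.name, block.complexity))

# Raw counts (loc, lloc, sloc, comments, blank lines, ...).
print(analyze(source))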

How to fix: Complexity

Complex classes like GPUDataTransmitter.wrap() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
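
As one hedged illustration of this advice (a sketch, not the repository's actual refactoring): the eager CPU-staging block inside wrap() in the listing below is a cohesive responsibility that could be extracted into its own small class. The CPUDataStager name and its interface are invented for this sketch.

import numpy as np


class CPUDataStager(object):
    """Hypothetical extracted component: it owns only the CPU-side
    staging of the data source (the first half of the original wrap())."""

    def __init__(self, shapes, dtypes):
        self.shapes = shapes
        self.dtypes = dtypes
        self.cpu_datas = []

    def load(self, data_source):
        if not self.cpu_datas:
            # Collect one list of batches per variable from the data source.
            columns = [[] for _ in self.shapes]
            for data_tuple in data_source:
                for i, item in enumerate(data_tuple):
                    columns[i].append(item)
            # Drop the last tuple and convert to numpy arrays, as the original code does.
            columns = [column[:-1] for column in columns]
            self.cpu_datas = [np.array(column, dtype=dtype)
                              for column, dtype in zip(columns, self.dtypes)]
        return self.cpu_datas

If GPUDataTransmitter held such a stager (for example self.stager = CPUDataStager(shapes, dtypes) set in __init__), wrap() would keep only the loop that slices cache_num batches at a time, transmits them, and yields indices, removing roughly half of its eleven conditions. The original file follows.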

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import theano
import theano.tensor as T

class GPUDataTransmitter(object):
    """
    Cache multiple batches on GPU.
    """

    def __init__(self, network, shapes, dtypes, cache_num=10):
        self.network = network
        self.shapes = shapes
        self.cache_num = cache_num
        self.dtypes = dtypes
        self.all_variables = self.network.input_variables + self.network.target_variables
        if len(self.all_variables) != len(shapes):
            raise Exception("The number of network variables is not identical with shapes")
        self.iterator = T.iscalar("i")
        self.gpu_caches = []
        self.cpu_datas = []
        for shape, dtype in zip(self.shapes, self.dtypes):
            cache_shape = [cache_num] + shape
            cache = theano.shared(np.zeros(cache_shape, dtype=dtype))
            self.gpu_caches.append(cache)

    def get_givens(self):
        givens = {}
        for var, cache in zip(self.all_variables, self.gpu_caches):
            givens[var] = cache[self.iterator]
        return givens

    def get_iterator(self):
        return self.iterator

    def transmit(self, *data_list):
        for cache, data in zip(self.gpu_caches, data_list):
            cache.set_value(data, borrow=True)

    def wrap(self, data_source):
        if not self.cpu_datas:
            # Load data source to CPU memory (all)
            datas = []
            for _ in self.shapes:
                datas.append([])
            for data_tuple in data_source:
                for i in range(len(data_tuple)):
                    datas[i].append(data_tuple[i])
            for i in range(len(datas)):
                datas[i] = datas[i][:-1]
            # Convert to numpy array
            for data, dtype in zip(datas, self.dtypes):
                self.cpu_datas.append(np.array(data, dtype=dtype))
        data_len = self.cpu_datas[0].shape[0]
        for i in xrange(0, data_len, self.cache_num):
            if i + self.cache_num > data_len:
                continue
            transmit_datas = [data[i:i+self.cache_num] for data in self.cpu_datas]
            self.transmit(*transmit_datas)
            for n in range(self.cache_num):
                yield [n]
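
For context, a hedged usage sketch (not from the repository) of how this class is presumably wired into a compiled Theano function: the GPU caches replace the network's symbolic inputs/targets via get_givens(), the compiled function takes only the iterator index, and wrap() drives the copy-then-iterate loop. DummyNetwork, the toy cost, and the random data_source below are stand-ins invented for the example.

import numpy as np
import theano
import theano.tensor as T

class DummyNetwork(object):
    """Stand-in exposing the two attributes GPUDataTransmitter expects."""
    def __init__(self):
        self.x = T.matrix("x")
        self.y = T.ivector("y")
        self.input_variables = [self.x]
        self.target_variables = [self.y]

network = DummyNetwork()
# Toy cost, just to have something to compile against the symbolic variables.
cost = T.mean(network.x) + T.mean(T.cast(network.y, theano.config.floatX))

# One slot per batch of 128 examples; 10 slots live on the GPU at a time.
transmitter = GPUDataTransmitter(network,
                                 shapes=[[128, 784], [128]],
                                 dtypes=[theano.config.floatX, "int32"],
                                 cache_num=10)

# The caches are substituted for x and y, indexed by the iterator scalar,
# so the compiled function receives only a batch index, never the data.
fn = theano.function(inputs=[transmitter.get_iterator()],
                     outputs=cost,
                     givens=transmitter.get_givens())

# A fake data source of 21 (inputs, targets) batch tuples.
data_source = [(np.random.randn(128, 784).astype(theano.config.floatX),
                np.random.randint(0, 10, size=128).astype("int32"))
               for _ in range(21)]

# wrap() stages everything in CPU memory once, then repeatedly copies
# cache_num batches to the GPU and yields single-element [index] lists.
for indices in transmitter.wrap(data_source):
    value = fn(indices[0])

The point of the givens substitution is that each call to the compiled function transfers only a single integer; the batch data itself is copied to the GPU in blocks of cache_num by transmit().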