Completed
Push — master ( 39f179...514f8f )
by Tinghui
01:17
created

BatchSequenceInjector.to_sequence()   C

Complexity

Conditions 7

Size

Total Lines 32

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 7
c 1
b 0
f 0
dl 0
loc 32
rs 5.5
1
import math
2
import logging
3
import numpy as np
4
5
logger = logging.getLogger(__name__)
6
7
8
class BatchInjector:
    """Retrieve dataset values in consecutive batches.

    The rows of the data (axis 0) are split into ``num_batches`` contiguous
    batches of ``batch_size`` rows. The last batch also absorbs any rows left
    over by the integer division, so it can be larger than ``batch_size``.
    Iteration wraps around after the last batch.

    Args:
        data_x (:obj:`numpy.ndarray`): Input feature array; samples are indexed along axis 0.
        data_y (:obj:`numpy.ndarray`, optional): Input label array with the same number of
            rows as ``data_x``. May be 1-D or N-D.
        batch_size (:obj:`int`): Batch size. Used when ``0 < batch_size <= len(data_x)``.
        num_batches (:obj:`int`): The number of batches to split the input data into. Only
            consulted when ``batch_size`` is not valid; must satisfy
            ``0 < num_batches <= len(data_x)``.

    Attributes:
        size (:obj:`int`): Number of input vectors.
        batch_size (:obj:`int`): Batch size.
        num_batches (:obj:`int`): Number of batches in the input data.
        num_epochs (:obj:`int`): Number of completed passes over the data.
        cur_batch (:obj:`int`): Index of the batch the next call to ``next_batch`` returns.
        data_x (:obj:`numpy.ndarray`): Reference to input feature array.
        data_y (:obj:`numpy.ndarray`): Reference to input label array.

    Raises:
        ValueError: If neither ``batch_size`` nor ``num_batches`` is valid, or if
            ``data_x`` and ``data_y`` have different numbers of rows.
    """

    def __init__(self, data_x, data_y=None, batch_size=-1, num_batches=-1):
        self.size = data_x.shape[0]
        if 0 < batch_size <= self.size:
            self.batch_size = batch_size
            self.num_batches = self.size // self.batch_size
        elif 0 < num_batches <= self.size:
            # Bounding num_batches by size guarantees batch_size >= 1; otherwise
            # every batch except the last would be empty.
            self.batch_size = self.size // num_batches
            self.num_batches = num_batches
        else:
            raise ValueError('Invalid batch_size or num_batches.')
        self.num_epochs = 0
        self.cur_batch = 0
        self.data_x = data_x
        self.data_y = data_y
        if data_y is not None and self.data_x.shape[0] != self.data_y.shape[0]:
            raise ValueError('data_x, data_y provided have different number of rows.')

    def next_batch(self):
        """Return the next batch and advance the batch cursor.

        Returns:
            ``data_x[start:end]`` if no labels were given, otherwise the tuple
            ``(data_x[start:end], data_y[start:end])``. After the last batch
            (which extends to the end of the data) the cursor wraps back to 0
            and ``num_epochs`` is incremented.
        """
        start = self.batch_size * self.cur_batch
        if self.cur_batch == self.num_batches - 1:
            # Last batch: take every remaining row, including the remainder
            # left over by the integer division in __init__.
            end = self.size
            self.cur_batch = 0
            self.num_epochs += 1
        else:
            end = start + self.batch_size
            self.cur_batch += 1
        # Slicing only axis 0 gives the same result as [start:end, :] for 2-D
        # arrays and also works for 1-D arrays (e.g. integer class labels).
        batch_x = self.data_x[start:end]
        if self.data_y is None:
            return batch_x
        return batch_x, self.data_y[start:end]

    def reset(self):
        """Reset the batch cursor and the epoch counter to zero."""
        self.cur_batch = 0
        self.num_epochs = 0
66
67
68
class BatchSequenceInjector:
    """Retrieve dataset values in batches, forming a sequence window for each sample.

    Sample ``i`` is the window ``data_x[i:i + seq_len]``. Its label, when
    ``data_y`` is given, is ``data_y[i + seq_len]``, the label of the row
    immediately after the window. There are ``len(data_x) - seq_len`` samples
    in total, split into ``num_batches`` contiguous batches. The last batch
    also absorbs any remainder.

    Args:
        data_x (:obj:`numpy.ndarray`): Input feature array; samples are indexed along axis 0.
        data_y (:obj:`numpy.ndarray`, optional): Input label array with the same number of
            rows as ``data_x``. May be 1-D or N-D.
        seq_len (:obj:`int`): Length of sequence.
        batch_size (:obj:`int`): Batch size. Used when ``0 < batch_size <= size``.
        num_batches (:obj:`int`): The number of batches to split the samples into. Only
            consulted when ``batch_size`` is not valid; must satisfy
            ``0 < num_batches <= size``.

    Attributes:
        seq_len (:obj:`int`): Length of sequence.
        size (:obj:`int`): Number of sequence samples (``len(data_x) - seq_len``).
        batch_size (:obj:`int`): Batch size.
        num_batches (:obj:`int`): Number of batches in the input data.
        num_epochs (:obj:`int`): Number of completed passes over the data.
        cur_batch (:obj:`int`): Index of the batch the next call to ``next_batch`` returns.
        data_x (:obj:`numpy.ndarray`): Reference to input feature array.
        data_y (:obj:`numpy.ndarray`): Reference to input label array.

    Raises:
        ValueError: If neither ``batch_size`` nor ``num_batches`` is valid (this
            includes the case ``seq_len >= len(data_x)``, where no sample exists),
            or if ``data_x`` and ``data_y`` have different numbers of rows.
    """

    def __init__(self, data_x, data_y=None, seq_len=100, batch_size=-1, num_batches=-1):
        self.seq_len = seq_len
        # Each sample needs seq_len rows of context plus one following row for its label.
        self.size = data_x.shape[0] - seq_len
        if 0 < batch_size <= self.size:
            self.batch_size = batch_size
            self.num_batches = self.size // self.batch_size
        elif 0 < num_batches <= self.size:
            # Bounding num_batches by size rejects size <= 0 (seq_len too long)
            # and guarantees batch_size >= 1.
            self.batch_size = self.size // num_batches
            self.num_batches = num_batches
        else:
            raise ValueError('Invalid batch_size or num_batches.')
        self.num_epochs = 0
        self.cur_batch = 0
        self.data_x = data_x
        self.data_y = data_y
        if data_y is not None and self.data_x.shape[0] != self.data_y.shape[0]:
            raise ValueError('data_x, data_y provided have different number of rows.')

    def next_batch(self):
        """Return the next batch of sequences and advance the batch cursor.

        Returns:
            The result of :meth:`to_sequence` for the current batch's sample
            range: ``seq_x`` alone, or ``(seq_x, seq_y)`` when labels were given.
            After the last batch (which extends to ``size``) the cursor wraps
            back to 0 and ``num_epochs`` is incremented.
        """
        start = self.batch_size * self.cur_batch
        if self.cur_batch == self.num_batches - 1:
            # Last batch: take every remaining sample, including the remainder
            # left over by the integer division in __init__.
            end = self.size
            self.cur_batch = 0
            self.num_epochs += 1
        else:
            end = start + self.batch_size
            self.cur_batch += 1
        return self.to_sequence(self.seq_len, self.data_x, self.data_y, start, end)

    def reset(self):
        """Reset the batch cursor and the epoch counter to zero."""
        self.cur_batch = 0
        self.num_epochs = 0

    @staticmethod
    def to_sequence(seq_len, x, y=None, start=None, end=None):
        """Turn a feature array into sequence windows of ``seq_len`` consecutive rows.

        Sample ``i`` (for ``start <= i < end``) is ``x[i:i + seq_len]``, and its
        label is ``y[i + seq_len]``, the row right after the window.

        Args:
            seq_len (:obj:`int`): Length of the sequence; must be non-negative.
            x (:obj:`numpy.ndarray`): Feature array, indexed by sample along axis 0,
                e.g. with shape (num_samples, num_features).
            y (:obj:`numpy.ndarray`, optional): Label array indexed by sample along
                axis 0, e.g. with shape (num_samples, num_classes) or (num_samples,).
            start (:obj:`int`, optional): First sample index. Defaults to 0.
            end (:obj:`int`, optional): One past the last sample index. Defaults to
                ``x.shape[0] - seq_len``, i.e. the last window whose label exists.

        Returns:
            ``(seq_x, seq_y)`` if ``y`` is provided, or ``seq_x`` otherwise.
            ``seq_x`` is a float32 array of shape ``(end - start, seq_len) + x.shape[1:]``,
            and ``seq_y`` is ``y[start + seq_len:end + seq_len]``.
            Returns None (and logs an error) if the requested range is out of
            bounds, i.e. ``start < 0``, ``start > end``, or
            ``end + seq_len > x.shape[0]``.

        Raises:
            ValueError: If ``seq_len`` is negative.
        """
        if seq_len < 0:
            raise ValueError('seq_len must be non-negative.')
        # Default each bound independently so a caller-supplied one is never discarded.
        if start is None:
            start = 0
        if end is None:
            end = x.shape[0] - seq_len
        if start < 0 or start > end or end + seq_len > x.shape[0]:
            logger.error('start/end out of bound.')
            return None
        # Row r of window_idx holds start + r, ..., start + r + seq_len - 1, so a
        # single fancy-indexing gather builds every window at once.
        window_idx = np.arange(start, end)[:, np.newaxis] + np.arange(seq_len)
        batch_x = x[window_idx].astype(np.float32)
        if y is None:
            return batch_x
        return batch_x, y[start + seq_len:end + seq_len]