import math
import random
import logging
import collections
import numpy as np

logger = logging.getLogger(__name__)


class BatchInjector:
    """Retrieves dataset values in batches.

    Args:
        data_x (:obj:`numpy.ndarray`): Input feature array.
        data_y (:obj:`numpy.ndarray`): Input label array.
        batch_size (:obj:`int`): Batch size.
        num_batches (:obj:`int`): The number of batches in the input data.

    Attributes:
        size (:obj:`int`): Number of input vectors.
        batch_size (:obj:`int`): Batch size.
        num_batches (:obj:`int`): Number of batches in the input data.
        num_epochs (:obj:`int`): Number of epochs completed so far.
        cur_batch (:obj:`int`): Current batch index.
        data_x (:obj:`numpy.ndarray`): Reference to input feature array.
        data_y (:obj:`numpy.ndarray`): Reference to input label array.
    """
    def __init__(self, data_x, data_y=None, batch_size=-1, num_batches=-1):
        self.size = data_x.shape[0]
        if 0 < batch_size <= self.size:
            self.batch_size = batch_size
            self.num_batches = math.floor(self.size / self.batch_size)
        elif num_batches > 0:
            self.batch_size = math.floor(self.size / num_batches)
            self.num_batches = num_batches
        else:
            raise ValueError('Invalid batch_size or num_batches.')
        self.num_epochs = 0
        self.cur_batch = 0
        self.data_x = data_x
        self.data_y = data_y
        if data_y is not None:
            if self.data_x.shape[0] != self.data_y.shape[0]:
                raise ValueError('data_x, data_y provided have different number of rows.')

    def next_batch(self):
        """Get Next Batch
        """
        if self.cur_batch == self.num_batches - 1:
            # The last batch of an epoch runs to the end of the data, then counters wrap.
            start = self.batch_size * self.cur_batch
            end = self.size
            self.cur_batch = 0
            self.num_epochs += 1
        else:
            start = self.batch_size * self.cur_batch
            end = start + self.batch_size
            self.cur_batch += 1
        if self.data_y is None:
            return self.data_x[start:end, :]
        else:
            return self.data_x[start:end, :], self.data_y[start:end, :]

    def reset(self):
        """Reset all counters
        """
        self.cur_batch = 0
        self.num_epochs = 0
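
# A minimal usage sketch for BatchInjector (illustrative only; the toy arrays and
# batch size below are assumptions, not part of the original module):
#
#     features = np.arange(20, dtype=np.float32).reshape(10, 2)  # 10 samples, 2 features
#     labels = np.arange(10, dtype=np.float32).reshape(10, 1)
#     injector = BatchInjector(features, labels, batch_size=4)   # num_batches = floor(10 / 4) = 2
#     batch_x, batch_y = injector.next_batch()                   # batch_x: (4, 2), batch_y: (4, 1)
#     injector.reset()                                           # rewind epoch and batch counters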


class BatchSequenceInjector:
    """Retrieves dataset values in batches and forms a sequence of events.

    Args:
        data_x (:obj:`numpy.ndarray`): Input feature array.
        data_y (:obj:`numpy.ndarray`): Input label array.
        length (:obj:`int`): Length of sequence.
        batch_size (:obj:`int`): Batch size.
        num_batches (:obj:`int`): The number of batches in the input data.
        with_seq (:obj:`bool`): If True, next_batch also returns a 1-D array holding the sequence length of each sample.

    Attributes:
        length (:obj:`int`): Length of sequence.
        size (:obj:`int`): Number of input vectors.
        batch_size (:obj:`int`): Batch size.
        num_batches (:obj:`int`): Number of batches in the input data.
        num_epochs (:obj:`int`): Number of epochs completed so far.
        cur_batch (:obj:`int`): Current batch index.
        data_x (:obj:`numpy.ndarray`): Reference to input feature array.
        data_y (:obj:`numpy.ndarray`): Reference to input label array.
    """
    def __init__(self, data_x, data_y=None, length=100, batch_size=-1, num_batches=-1, with_seq=False):
        self.with_seq = with_seq
        self.length = length
        self.size = data_x.shape[0] - length
        if 0 < batch_size <= self.size:
            self.batch_size = batch_size
            self.num_batches = math.floor(self.size / self.batch_size)
        elif num_batches > 0:
            self.batch_size = math.floor(self.size / num_batches)
            self.num_batches = num_batches
        else:
            raise ValueError('Invalid batch_size or num_batches.')
        self.num_epochs = 0
        self.cur_batch = 0
        self.data_x = data_x
        self.data_y = data_y
        if data_y is not None:
            if self.data_x.shape[0] != self.data_y.shape[0]:
                raise ValueError('data_x, data_y provided have different number of rows.')

    def next_batch(self, skip=1):
        """Get Next Batch

        Args:
            skip (:obj:`int`): Step size in batches; the default of 1 returns the immediately following batch.
        """
        # Advance by (skip - 1) extra batches; wrap around and count an epoch if past the end.
        self.cur_batch += skip - 1
        if self.cur_batch > self.num_batches - 1:
            self.cur_batch = 0
            self.num_epochs += 1
        if self.cur_batch == self.num_batches - 1:
            # The last batch of an epoch runs to the end of the data, then counters wrap.
            start = self.batch_size * self.cur_batch
            end = self.size
            self.cur_batch = 0
            self.num_epochs += 1
        else:
            start = self.batch_size * self.cur_batch
            end = start + self.batch_size
            self.cur_batch += 1
        return self.to_sequence(self.length, self.data_x, self.data_y, start, end, with_seq=self.with_seq)

    def reset(self):
        """Reset all counters
        """
        self.cur_batch = 0
        self.num_epochs = 0

    @staticmethod
    def to_sequence(length, x, y=None, start=None, end=None, with_seq=False):
        """Turn a feature array into a sequence array where each new sample holds `length` consecutive original feature vectors.

        Args:
            length (:obj:`int`): Length of the sequence.
            x (:obj:`numpy.ndarray`): Feature array, with shape (num_samples, num_features).
            y (:obj:`numpy.ndarray`): Label array, with shape (num_samples, num_classes).
            start (:obj:`int`): Start index.
            end (:obj:`int`): End index.
            with_seq (:obj:`bool`): If True, append a 1-D array filled with `length` to the returned tuple.

        Returns:
            (seq_x, seq_y) if y is provided, or (seq_x,) if y is not provided; when with_seq is True,
            an additional array of sequence lengths is appended.
            seq_x is a numpy array of shape (num_samples, length, num_features), and seq_y is a numpy
            array of shape (num_samples, length, num_classes).
            num_samples is bounded by the values of start and end.
            If start or end is not specified, the code will use the full data provided, so that the
            arrays returned have (num_samples - length) samples.
            Returns None if start or end (plus length) falls outside the data.
        """
        if start is None or end is None:
            start = 0
            end = x.shape[0] - length
        if (start+length) > x.shape[0] or (end+length) > x.shape[0]:
            logger.error('start/end out of bound.')
            return None
        batch_x = np.zeros((end - start, length, x.shape[1]), np.float32)
        for i in range(start, end):
            batch_x[i-start, :, :] = x[i:i + length, :]
        return_tuple = tuple([batch_x])
        if y is not None:
            batch_y = np.zeros((end - start, length, y.shape[1]), np.float32)
            for i in range(start, end):
                batch_y[i-start, :, :] = y[i:i + length, :]
            return_tuple += tuple([batch_y])
        if with_seq:
            seq_ar = np.zeros((end - start,), np.float32)
            seq_ar[:] = length
            return_tuple += tuple([seq_ar])
        return return_tuple
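
# A minimal usage sketch for BatchSequenceInjector (illustrative only; the toy array,
# sequence length, and batch size below are assumptions, not part of the original module):
#
#     x = np.arange(40, dtype=np.float32).reshape(20, 2)            # 20 samples, 2 features
#     injector = BatchSequenceInjector(x, length=5, batch_size=5)   # size = 20 - 5 = 15 windows
#     (seq_x,) = injector.next_batch()                              # seq_x: (5, 5, 2)
#     (all_x,) = BatchSequenceInjector.to_sequence(5, x)            # all_x: (15, 5, 2)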


class SkipGramInjector:
    """Skip-Gram Batch Injector

    It generates k-skip-2-gram pairs based on the input sequence.

    Args:
        data_x (:obj:`np.ndarray`): 1D array of integer indices.
        batch_size (:obj:`int`): Size of each batch to be generated.
        num_skips (:obj:`int`): How many times to re-use an input to generate a label.
        skip_window (:obj:`int`): How many items to consider to the left and right.

    Attributes:
        data_x (:obj:`np.ndarray`): 1D array of integer indices.
        batch_size (:obj:`int`): Size of each batch to be generated.
        num_skips (:obj:`int`): How many times to re-use an input to generate a label.
        skip_window (:obj:`int`): How many items to consider to the left and right.
        data_index (:obj:`int`): Current index used to generate the next batch.
    """
    def __init__(self, data_x, batch_size, num_skips, skip_window):
        assert batch_size % num_skips == 0
        assert num_skips <= 2 * skip_window
        self.data_x = data_x
        self.batch_size = batch_size
        self.num_skips = num_skips
        self.skip_window = skip_window
        self.data_index = 0

    def next_batch(self):
        """Get Next Batch
        """
        # Initialize batch and label array
        batch = np.ndarray(shape=(self.batch_size), dtype=np.int32)
        labels = np.ndarray(shape=(self.batch_size, 1), dtype=np.int32)
        # span is the size of window we are sampling from
        span = 2 * self.skip_window + 1  # [ skip_window target skip_window ]
        # Add data in the buffer to a queue
        buffer = collections.deque(maxlen=span)
        for _ in range(span):
            buffer.append(self.data_x[self.data_index])
            self.data_index = (self.data_index + 1) % len(self.data_x)
        # Now, populate the k-skip-2-gram data-label pairs with random sampling
        for i in range(self.batch_size // self.num_skips):
            target = self.skip_window  # target label at the center of the buffer
            targets_to_avoid = [self.skip_window]
            for j in range(self.num_skips):
                while target in targets_to_avoid:
                    target = random.randint(0, span - 1)
                targets_to_avoid.append(target)
                batch[i * self.num_skips + j] = buffer[self.skip_window]
                labels[i * self.num_skips + j, 0] = buffer[target]
            buffer.append(self.data_x[self.data_index])
            self.data_index = (self.data_index + 1) % len(self.data_x)
        # Backtrack a little bit to avoid skipping words at the end of a batch
        self.data_index = (self.data_index + len(self.data_x) - span) % len(self.data_x)
        return batch, labels
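
# A minimal usage sketch for SkipGramInjector (illustrative only; the toy index sequence
# and parameters below are assumptions, not part of the original module):
#
#     token_ids = np.arange(8, dtype=np.int32)                 # toy sequence of word indices
#     injector = SkipGramInjector(token_ids, batch_size=8, num_skips=2, skip_window=1)
#     centers, contexts = injector.next_batch()
#     # centers: (8,) center-word ids; contexts: (8, 1) ids drawn from each center's window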