|
@@ 88-105 (lines=18) @@
|
| 85 |
|
data_x (:obj:`numpy.ndarray`): Reference to input feature array. |
| 86 |
|
data_y (:obj:`numpy.ndarray`): Reference to input label array. |
| 87 |
|
""" |
| 88 |
|
def __init__(self, data_x, data_y=None, seq_len=100, batch_size=-1, num_batches=-1):
    """Configure sequence batching over ``data_x`` (and optional ``data_y``).

    ``size`` is ``data_x.shape[0] - seq_len``. A positive ``batch_size`` that
    does not exceed ``size`` takes priority; otherwise a positive
    ``num_batches`` that does not exceed ``size`` is used to derive
    ``batch_size``. In both cases ``num_batches * batch_size <= size``; any
    remainder that does not fill a whole batch is not counted.

    Args:
        data_x (:obj:`numpy.ndarray`): Input feature array; only
            ``shape[0]`` is read here.
        data_y (:obj:`numpy.ndarray`, optional): Label array; must have the
            same number of rows as ``data_x``.
        seq_len (int): Sequence length, subtracted from the row count to
            obtain ``size``.
        batch_size (int): Items per batch; used when
            ``0 < batch_size <= size``.
        num_batches (int): Number of batches; used when ``batch_size`` is not
            usable, and only when ``0 < num_batches <= size`` so that the
            derived ``batch_size`` is at least 1.

    Raises:
        ValueError: If ``seq_len`` leaves ``size <= 0``, if neither
            ``batch_size`` nor ``num_batches`` is valid, or if ``data_x`` and
            ``data_y`` have different numbers of rows.
    """
    self.seq_len = seq_len
    # NOTE(review): presumably the number of valid sequence start indices;
    # confirm against next_batch.
    self.size = data_x.shape[0] - seq_len
    if self.size <= 0:
        # Without this guard the num_batches branch would compute a
        # batch_size <= 0 instead of failing.
        raise ValueError('seq_len must be smaller than the number of rows in data_x.')
    if 0 < batch_size <= self.size:
        self.batch_size = batch_size
        # Exact integer floor division (no float rounding for large sizes).
        self.num_batches = self.size // self.batch_size
    elif 0 < num_batches <= self.size:
        # The upper bound keeps batch_size >= 1; more batches than items
        # would otherwise floor batch_size to 0.
        self.batch_size = self.size // num_batches
        self.num_batches = num_batches
    else:
        raise ValueError('Invalid batch_size or num_batches.')
    self.num_epochs = 0
    self.cur_batch = 0
    self.data_x = data_x
    self.data_y = data_y
    if data_y is not None:
        if self.data_x.shape[0] != self.data_y.shape[0]:
            raise ValueError('data_x, data_y provided have different number of rows.')
| 106 |
|
|
| 107 |
|
def next_batch(self): |
| 108 |
|
"""Get Next Batch |
|
@@ 26-42 (lines=17) @@
|
| 23 |
|
data_x (:obj:`numpy.ndarray`): Reference to input feature array. |
| 24 |
|
data_y (:obj:`numpy.ndarray`): Reference to input label array. |
| 25 |
|
""" |
| 26 |
|
def __init__(self, data_x, data_y=None, batch_size=-1, num_batches=-1):
    """Configure batching over the rows of ``data_x`` (and optional ``data_y``).

    ``size`` is ``data_x.shape[0]``. A positive ``batch_size`` that does not
    exceed ``size`` takes priority; otherwise a positive ``num_batches`` that
    does not exceed ``size`` is used to derive ``batch_size``. In both cases
    ``num_batches * batch_size <= size``; any remainder that does not fill a
    whole batch is not counted.

    Args:
        data_x (:obj:`numpy.ndarray`): Input feature array; only
            ``shape[0]`` is read here.
        data_y (:obj:`numpy.ndarray`, optional): Label array; must have the
            same number of rows as ``data_x``.
        batch_size (int): Rows per batch; used when
            ``0 < batch_size <= size``.
        num_batches (int): Number of batches; used when ``batch_size`` is not
            usable, and only when ``0 < num_batches <= size`` so that the
            derived ``batch_size`` is at least 1.

    Raises:
        ValueError: If neither ``batch_size`` nor ``num_batches`` is valid
            (including an empty ``data_x``), or if ``data_x`` and ``data_y``
            have different numbers of rows.
    """
    self.size = data_x.shape[0]
    if 0 < batch_size <= self.size:
        self.batch_size = batch_size
        # Exact integer floor division (no float rounding for large sizes).
        self.num_batches = self.size // self.batch_size
    elif 0 < num_batches <= self.size:
        # The upper bound keeps batch_size >= 1; more batches than rows
        # would otherwise floor batch_size to 0.
        self.batch_size = self.size // num_batches
        self.num_batches = num_batches
    else:
        raise ValueError('Invalid batch_size or num_batches.')
    self.num_epochs = 0
    self.cur_batch = 0
    self.data_x = data_x
    self.data_y = data_y
    if data_y is not None:
        if self.data_x.shape[0] != self.data_y.shape[0]:
            raise ValueError('data_x, data_y provided have different number of rows.')
| 43 |
|
|
| 44 |
|
def next_batch(self): |
| 45 |
|
"""Get Next Batch |