@@ 90-108 (lines=19) @@ | ||
87 | data_x (:obj:`numpy.ndarray`): Reference to input feature array. |
|
88 | data_y (:obj:`numpy.ndarray`): Reference to input label array. |
|
89 | """ |
|
def __init__(self, data_x, data_y=None, length=100, batch_size=-1, num_batches=-1, with_seq=False):
    """Initialise batching state over windows of ``length`` rows of ``data_x``.

    ``batch_size`` takes precedence when ``0 < batch_size <= size``; otherwise a
    positive ``num_batches`` is used to derive the batch size.

    Args:
        data_x (:obj:`numpy.ndarray`): Reference to input feature array. Only
            ``data_x.shape[0]`` (the row count) is read here.
        data_y (:obj:`numpy.ndarray`, optional): Reference to input label array.
            When given it must have the same number of rows as ``data_x``.
        length (int): Window length, stored as ``self.length``. The usable row
            count is ``data_x.shape[0] - length``.
        batch_size (int): Rows per batch. Used when ``0 < batch_size <= size``,
            with ``num_batches = floor(size / batch_size)``.
        num_batches (int): Number of batches. Used when ``batch_size`` is not
            valid, with ``batch_size = floor(size / num_batches)``.
        with_seq (bool): Stored as ``self.with_seq``. Its effect is not visible
            in this constructor.

    Raises:
        ValueError: If neither ``batch_size`` nor ``num_batches`` yields a batch
            of at least one row, or if ``data_x`` and ``data_y`` have different
            numbers of rows.
    """
    self.with_seq = with_seq
    self.length = length
    # NOTE(review): this counts N - length positions, not N - length + 1.
    # Presumably the row after each window is also consumed; confirm in next_batch.
    self.size = data_x.shape[0] - length
    if 0 < batch_size <= self.size:
        self.batch_size = batch_size
        self.num_batches = math.floor(self.size / self.batch_size)
    elif num_batches > 0:
        self.batch_size = math.floor(self.size / num_batches)
        self.num_batches = num_batches
    else:
        raise ValueError('Invalid batch_size or num_batches.')
    # The derived size drops below one row when num_batches > size, or when
    # length >= data_x.shape[0]. Reject that here instead of producing empty
    # or negative-sized batches later.
    if self.batch_size < 1:
        raise ValueError('Invalid batch_size or num_batches.')
    # Iteration counters, presumably advanced by next_batch.
    self.num_epochs = 0
    self.cur_batch = 0
    self.data_x = data_x
    self.data_y = data_y
    if data_y is not None:
        if self.data_x.shape[0] != self.data_y.shape[0]:
            raise ValueError('data_x, data_y provided have different number of rows.')
|
109 | ||
110 | def next_batch(self, skip=1): |
|
111 | """Get Next Batch |
|
@@ 28-44 (lines=17) @@ | ||
25 | data_x (:obj:`numpy.ndarray`): Reference to input feature array. |
|
26 | data_y (:obj:`numpy.ndarray`): Reference to input label array. |
|
27 | """ |
|
def __init__(self, data_x, data_y=None, batch_size=-1, num_batches=-1):
    """Initialise batching state over the rows of ``data_x``.

    ``batch_size`` takes precedence when ``0 < batch_size <= size``; otherwise a
    positive ``num_batches`` is used to derive the batch size.

    Args:
        data_x (:obj:`numpy.ndarray`): Reference to input feature array. Only
            ``data_x.shape[0]`` (the row count) is read here.
        data_y (:obj:`numpy.ndarray`, optional): Reference to input label array.
            When given it must have the same number of rows as ``data_x``.
        batch_size (int): Rows per batch. Used when ``0 < batch_size <= size``,
            with ``num_batches = floor(size / batch_size)``.
        num_batches (int): Number of batches. Used when ``batch_size`` is not
            valid, with ``batch_size = floor(size / num_batches)``.

    Raises:
        ValueError: If neither ``batch_size`` nor ``num_batches`` yields a batch
            of at least one row, or if ``data_x`` and ``data_y`` have different
            numbers of rows.
    """
    self.size = data_x.shape[0]
    if 0 < batch_size <= self.size:
        self.batch_size = batch_size
        self.num_batches = math.floor(self.size / self.batch_size)
    elif num_batches > 0:
        self.batch_size = math.floor(self.size / num_batches)
        self.num_batches = num_batches
    else:
        raise ValueError('Invalid batch_size or num_batches.')
    # num_batches > size (or an empty data_x) floors the derived size to 0.
    # Reject that here instead of producing empty batches later.
    if self.batch_size < 1:
        raise ValueError('Invalid batch_size or num_batches.')
    # Iteration counters, presumably advanced by next_batch.
    self.num_epochs = 0
    self.cur_batch = 0
    self.data_x = data_x
    self.data_y = data_y
    if data_y is not None:
        if self.data_x.shape[0] != self.data_y.shape[0]:
            raise ValueError('data_x, data_y provided have different number of rows.')
|
45 | ||
46 | def next_batch(self): |
|
47 | """Get Next Batch |