| Total Complexity | 6 |
| Total Lines | 22 |
| Duplicated Lines | 0 % |
| 1 | #!/usr/bin/env python |
||
class SequentialMiniBatches(MiniBatches):
    """
    Mini batch class for sequential data.

    Every batch is passed through ``pad_dataset`` and then split into a
    separate input array and target array.
    """

    def __init__(self, dataset, batch_size=20, padding_length=-1, fix_batch_size=False):
        """
        :param dataset: dataset forwarded to ``MiniBatches``.
        :param batch_size: number of samples per batch, forwarded to
            ``MiniBatches``.
        :param padding_length: forwarded to ``pad_dataset`` for every batch.
            NOTE(review): the meaning of the default ``-1`` is defined by
            ``pad_dataset`` (not visible here) -- confirm.
        :param fix_batch_size: when True, any batch whose length differs from
            the batch size (in practice only the trailing partial batch) is
            skipped instead of yielded.
        """
        super(SequentialMiniBatches, self).__init__(dataset, batch_size=batch_size)
        self.padding_length = padding_length
        self._fix_batch_size = fix_batch_size

    def _yield_data(self, subset):
        """
        Yield padded ``(x, y)`` batches of ``self.size`` samples from ``subset``.

        ``self.size`` is read from the instance; presumably ``MiniBatches``
        sets it from ``batch_size`` -- confirm against the base class.

        :param subset: sliceable sequence of ``(x, y)`` pairs.
        :yields: ``(x_array, y_array)`` tuples of numpy arrays built from the
            padded batch.
        """
        # ``range`` (not the Python-2-only ``xrange``) yields the same start
        # indices on Python 2 and 3; the list it builds on Python 2 is tiny.
        for start in range(0, len(subset), self.size):
            chunk = subset[start:start + self.size]
            # Check the length before padding so that a short trailing chunk,
            # which is dropped anyway, is never padded or copied into arrays.
            # NOTE(review): assumes pad_dataset returns exactly one pair per
            # input pair (the original compared the padded array's length).
            if self._fix_batch_size and len(chunk) != self.size:
                continue
            batch = pad_dataset(chunk, PADDING_SIDE, self.padding_length)
            x_set, y_set = [], []
            for x, y in batch:
                x_set.append(x)
                y_set.append(y)
            yield np.array(x_set), np.array(y_set)