Completed
Push — master ( f73e69...91b7c0 )
by Raphael
01:35
created

SequentialMiniBatches._yield_data()   B

Complexity

Conditions 5

Size

Total Lines 12

Duplication

Lines 0
Ratio 0 %
Metric Value
cc 5
dl 0
loc 12
rs 8.5454
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
5
from . import MiniBatches
6
from padding import pad_dataset
7
import numpy as np
8
9
# Padding-side argument passed to every pad_dataset call made by
# SequentialMiniBatches._yield_data; presumably selects which end of each
# sequence pad_dataset pads (TODO confirm against pad_dataset).
PADDING_SIDE = "right"
10
11
class SequentialMiniBatches(MiniBatches):
    """
    Mini batch class for sequential data.

    Each batch is a slice of the subset padded with ``pad_dataset`` (using
    ``PADDING_SIDE`` and ``padding_length``) and split into a pair of NumPy
    arrays ``(x_batch, y_batch)``.
    """

    def __init__(self, dataset, batch_size=20, padding_length=-1, fix_batch_size=False):
        """
        :param dataset: dataset forwarded to ``MiniBatches``.
        :param batch_size: forwarded to ``MiniBatches``. ``_yield_data``
            reads the batch size back from ``self.size``, which is assumed
            to be set from this value by the base class (TODO confirm).
        :param padding_length: forwarded unchanged to ``pad_dataset``. The
            meaning of the default ``-1`` is defined there (presumably "pad
            to the longest sequence in the batch" -- TODO confirm).
        :param fix_batch_size: when True, a trailing batch with fewer than
            ``self.size`` examples is not yielded.
        """
        super(SequentialMiniBatches, self).__init__(dataset, batch_size=batch_size)
        self.padding_length = padding_length
        self._fix_batch_size = fix_batch_size

    def _yield_data(self, subset):
        """
        Yield ``(x_batch, y_batch)`` NumPy array pairs built from ``subset``.

        ``subset`` is cut into consecutive slices of ``self.size`` items.
        Each slice is padded with ``pad_dataset`` and the resulting ``(x, y)``
        pairs are stacked into two arrays with ``np.array``.

        When ``fix_batch_size`` is set, a trailing slice shorter than
        ``self.size`` is skipped before it is padded, so no work is spent on
        a batch that would be discarded anyway.

        :param subset: sliceable sequence of examples supporting ``len``.
        """
        size = self.size
        # ``range`` instead of ``xrange``: identical stepping, but also valid
        # on Python 3 where ``xrange`` no longer exists.
        for start in range(0, len(subset), size):
            chunk = subset[start:start + size]
            # Drop a short trailing batch up front rather than padding and
            # converting it only to throw it away. Assumes pad_dataset
            # returns exactly one padded pair per input example.
            if self._fix_batch_size and len(chunk) != size:
                continue
            # Materialize once so the pairs can be traversed twice even if
            # pad_dataset hands back an iterator.
            pairs = list(pad_dataset(chunk, PADDING_SIDE, self.padding_length))
            x_set = np.array([x for x, _ in pairs])
            y_set = np.array([y for _, y in pairs])
            yield x_set, y_set