MiniBatches.__init__()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 7

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 1
c 2
b 0
f 0
dl 0
loc 7
rs 9.4285
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
5
from . import Dataset
6
import numpy as np
7
8
class MiniBatches(Dataset):
9
    """
10
    Convert data into mini-batches.
11
    """
12
13
    def __init__(self, dataset, batch_size=20, cache=True):
14
        self.origin = dataset
15
        self.size = batch_size
16
        self._cached_train_set = None
17
        self._cached_valid_set = None
18
        self._cached_test_set = None
19
        self.cache = cache
20
21
    def _yield_data(self, subset):
22
        if type(subset) != list:
23
            subset = list(subset)
24
        for i in xrange(0, len(subset), self.size):
25
            yield map(np.array, list(zip(*subset[i:i + self.size])))
26
27
    def train_set(self):
28
        if self.cache and self._cached_train_set is not None:
29
            return self._cached_train_set
30
31
        data_generator = self._yield_data(self.origin.train_set())
32
        if data_generator is None:
33
            return None
34
        if self.cache:
35
            self._cached_train_set = list(data_generator)
36
            return self._cached_train_set
37
        else:
38
            return data_generator
39
40 View Code Duplication
    def test_set(self):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
41
        if not self.origin.test_set():
42
            return None
43
        if self.cache and self._cached_test_set is not None:
44
            return self._cached_test_set
45
46
        data_generator = self._yield_data(self.origin.test_set())
47
        if data_generator is None:
48
            return None
49
        if self.cache:
50
            self._cached_test_set = list(data_generator)
51
            return self._cached_test_set
52
        else:
53
            return data_generator
54
55 View Code Duplication
    def valid_set(self):
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
56
        if not self.origin.valid_set():
57
            return None
58
        if self.cache and self._cached_valid_set is not None:
59
            return self._cached_valid_set
60
61
        data_generator = self._yield_data(self.origin.valid_set())
62
        if data_generator is None:
63
            return None
64
        if self.cache:
65
            self._cached_valid_set = list(data_generator)
66
            return self._cached_valid_set
67
        else:
68
            return data_generator
69
70
    def train_size(self):
71
        train_size = self.origin.train_size()
72
        if train_size is None:
73
            train_size = len(list(self.origin.train_set()))
74
        return train_size / self.size