1
|
|
|
#!/usr/bin/env python |
2
|
|
|
# -*- coding: utf-8 -*- |
3
|
|
|
|
4
|
|
|
|
5
|
|
|
from . import Dataset |
6
|
|
|
import numpy as np |
7
|
|
|
|
8
|
|
|
class MiniBatches(Dataset):
    """
    Convert data into mini-batches.

    Wraps another dataset (``origin``) and regroups each of its subsets into
    consecutive batches of ``batch_size`` samples. Each batch is a list with
    one ``np.array`` per sample field: the samples of a slice are transposed
    with ``zip(*...)``. Samples are therefore presumably tuples such as
    ``(x, y)`` -- confirm against the wrapped dataset. The last batch of a
    subset may hold fewer than ``batch_size`` samples.

    :param dataset: the wrapped dataset; must provide ``train_set()``,
        ``valid_set()``, ``test_set()`` and ``train_size()``.
    :param batch_size: number of samples per batch; must be positive.
    :param cache: when true, each subset is batched once and the resulting
        list is reused on later calls; when false, a fresh generator of
        batches is returned on every call.
    :raises ValueError: if ``batch_size`` is not positive.
    """

    def __init__(self, dataset, batch_size=20, cache=True):
        if batch_size <= 0:
            # range() raises on a zero step and silently yields nothing for
            # a negative one, so reject both here with a clear message.
            raise ValueError("batch_size must be positive, got %r" % (batch_size,))
        self.origin = dataset
        self.size = batch_size
        # Per-subset caches; None means "not computed yet".
        self._cached_train_set = None
        self._cached_valid_set = None
        self._cached_test_set = None
        self.cache = cache

    def _yield_data(self, subset):
        """
        Yield successive batches of ``self.size`` samples from ``subset``.

        Each batch is a list holding one ``np.array`` per sample field.
        """
        if not isinstance(subset, list):
            # Materialize iterators/generators so the samples can be sliced.
            subset = list(subset)
        for start in range(0, len(subset), self.size):
            fields = zip(*subset[start:start + self.size])
            # Build a real list rather than a lazy ``map`` object: on Python 3
            # a map is exhausted after one pass, which would break cached
            # batches that are iterated on every epoch.
            yield [np.array(field) for field in fields]

    def _batches(self, fetch_subset, cache_attr, empty_is_missing):
        """
        Shared implementation of ``train_set``, ``test_set`` and ``valid_set``.

        :param fetch_subset: zero-argument callable returning the origin's
            samples for one subset (e.g. ``self.origin.valid_set``).
        :param cache_attr: name of the attribute used as this subset's cache.
        :param empty_is_missing: when true, an empty (falsy) subset is
            reported as ``None`` instead of producing zero batches.
        :returns: a list of batches (cache on), a generator of batches
            (cache off), or ``None`` when the origin has no such subset.
        """
        if self.cache:
            cached = getattr(self, cache_attr)
            if cached is not None:
                # Serve from cache without calling into the origin dataset.
                return cached
        # Fetch once and reuse the result for both the availability check
        # and the batching itself.
        subset = fetch_subset()
        if subset is None or (empty_is_missing and not subset):
            return None
        batches = self._yield_data(subset)
        if self.cache:
            batches = list(batches)
            setattr(self, cache_attr, batches)
        return batches

    def train_set(self):
        """
        Return the training subset as mini-batches.

        :returns: list (cache on) or generator (cache off) of batches, or
            ``None`` if the origin returns no training subset.
        """
        return self._batches(self.origin.train_set, "_cached_train_set",
                             empty_is_missing=False)

    def test_set(self):
        """
        Return the test subset as mini-batches.

        :returns: list (cache on) or generator (cache off) of batches, or
            ``None`` if the origin's test subset is missing or empty.
        """
        return self._batches(self.origin.test_set, "_cached_test_set",
                             empty_is_missing=True)

    def valid_set(self):
        """
        Return the validation subset as mini-batches.

        :returns: list (cache on) or generator (cache off) of batches, or
            ``None`` if the origin's validation subset is missing or empty.
        """
        return self._batches(self.origin.valid_set, "_cached_valid_set",
                             empty_is_missing=True)

    def train_size(self):
        """
        Return the number of training mini-batches as an integer.

        ``_yield_data`` also emits the trailing partial batch, so the count is
        ceil(samples / batch_size). If the origin does not report its size,
        the training subset is materialized to count the samples.
        """
        train_size = self.origin.train_size()
        if train_size is None:
            train_size = len(list(self.origin.train_set()))
        # Integer ceiling division; identical result on Python 2 and 3.
        return (train_size + self.size - 1) // self.size