Completed
Push — master ( b93253...24f24c )
by David
03:31
created

blocks/log/sqlite.py (10 issues)

1
"""SQLite backend for the main loop log."""
0 ignored issues
show
There seems to be a cyclic import (blocks.bricks.base -> blocks.graph -> blocks.graph.bn -> blocks.filter).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.bn -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.architectures -> blocks.bricks.simple -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.misc -> blocks.bricks.simple -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.architectures -> blocks.bricks.recurrent.base -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.interfaces -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.recurrent -> blocks.bricks.recurrent.misc -> blocks.bricks.parallel -> blocks.bricks.simple -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
There seems to be a cyclic import (blocks.bricks -> blocks.bricks.bn -> blocks.bricks.sequences -> blocks.bricks.simple -> blocks.bricks.wrappers -> blocks.bricks.base -> blocks.graph -> blocks.graph.bn).

Cyclic imports may cause partly loaded modules to be returned. This might lead to unexpected runtime behavior which is hard to debug.

Loading history...
2
import sqlite3
3
import warnings
4
from collections import MutableMapping, Mapping
5
from operator import itemgetter
6
7
import numpy
8
import six
9
from six.moves import cPickle, map
10
11
from blocks.config import config
12
from .log import TrainingLogBase
13
14
15
ANCESTORS_QUERY = """
16
WITH parents (parent, child) AS (
17
    SELECT uuid, value FROM status
18
    WHERE key = 'resumed_from' AND uuid = ?
19
    UNION ALL
20
    SELECT uuid, value FROM status
21
    INNER JOIN parents ON status.uuid = parents.child
22
    WHERE key = 'resumed_from'
23
),
24
ancestors AS (SELECT parent FROM parents)
25
"""
26
27
LARGE_BLOB_WARNING = """
28
29
A {} object of {} bytes was stored in the SQLite database. SQLite natively \
30
only supports numbers and text. Other objects will be pickled before being \
31
saved. For large objects, this can be slow and degrade performance of the \
32
database."""
33
34
35
def adapt_obj(obj):
    """Binarize objects to be stored in an SQLite database.

    Parameters
    ----------
    obj : object
        Any picklable object.

    Returns
    -------
    blob : memoryview
        A buffer (Python 2) or memoryview (Python 3) of the pickled object
        that can be stored as a BLOB in an SQLite database.

    """
    pickled = cPickle.dumps(obj)
    blob = sqlite3.Binary(pickled)
    if len(blob) > config.max_blob_size:
        message = LARGE_BLOB_WARNING.format(type(obj), len(blob))
        warnings.warn('large objects stored in SQLite' + message)
        # The message varies per object, so a plain "once" filter would not
        # deduplicate it; install an explicit ignore filter instead.
        warnings.filterwarnings('ignore', 'large objects .*')
    return blob
57
58
59
def adapt_ndarray(obj):
    """Convert NumPy scalars to floats before storing in SQLite.

    This makes it easier to inspect the database, and speeds things up.

    Parameters
    ----------
    obj : ndarray
        A NumPy array.

    Returns
    -------
    float or memoryview
        If the array was a scalar, it returns a floating point number.
        Otherwise it binarizes the NumPy array using :func:`adapt_obj`

    """
    # Non-scalar arrays cannot be stored natively; fall back to pickling.
    if obj.ndim != 0:
        return adapt_obj(obj)
    return float(obj)
80
81
82
def _get_row(row, key):
    """Handle the returned row e.g. unpickle if needed.

    Raises a :class:`KeyError` for `key` when `row` is ``None`` (no match
    in the database).
    """
    if row is None:
        raise KeyError(key)
    value = row[0]
    # Resumption UUIDs are stored as raw bytes and must not be unpickled.
    if key != 'resumed_from' and isinstance(value, (sqlite3.Binary, bytes)):
        value = cPickle.loads(bytes(value))
    return value
92
93
94
def _register_adapter(value, key):
    """Register an adapter if the type of value is unknown."""
    # Assuming no storage of non-simple types on channel 'resumed_from'
    if key == 'resumed_from':
        return
    simple_types = (type(None), int, float, six.string_types, bytes,
                    numpy.ndarray)
    if not isinstance(value, simple_types):
        # Unknown type: store it pickled from now on.
        sqlite3.register_adapter(type(value), adapt_obj)
101
102
103
class SQLiteLog(TrainingLogBase, Mapping):
    r"""Training log using SQLite as a backend.

    Parameters
    ----------
    database : str, optional
        The database (file) to connect to. Can also be `:memory:`. See
        :func:`sqlite3.connect` for details. Uses `config.sqlite_database`
        by default.
    \*\*kwargs
        Arguments to pass to :class:`TrainingLogBase`

    """
    def __init__(self, database=None, **kwargs):
        if database is None:
            database = config.sqlite_database
        self.database = database
        # Goes through the `conn` property setter, which caches the
        # connection in `_conn`.
        self.conn = sqlite3.connect(database)
        # Store NumPy scalars natively as floats, other arrays pickled.
        sqlite3.register_adapter(numpy.ndarray, adapt_ndarray)
        with self.conn:
            # The `value` column is deliberately typeless so SQLite stores
            # numbers, text or blobs as given.
            self.conn.execute("""CREATE TABLE IF NOT EXISTS entries (
                                   uuid TEXT NOT NULL,
                                   time INT NOT NULL,
                                   "key" TEXT NOT NULL,
                                   value,
                                   PRIMARY KEY(uuid, time, "key")
                                 );""")
            self.conn.execute("""CREATE TABLE IF NOT EXISTS status (
                                   uuid TEXT NOT NULL,
                                   "key" text NOT NULL,
                                   value,
                                   PRIMARY KEY(uuid, "key")
                                 );""")
        self.status = SQLiteStatus(self)
        super(SQLiteLog, self).__init__(**kwargs)

    @property
    def conn(self):
        # Re-open the connection lazily: after unpickling `_conn` is absent
        # (see `__getstate__`) and is recreated here from `self.database`.
        if not hasattr(self, '_conn'):
            self._conn = sqlite3.connect(self.database)
        return self._conn

    @conn.setter
    def conn(self, value):
        self._conn = value

    def __getstate__(self):
        """Retrieve the state for pickling.

        :class:`sqlite3.Connection` objects are not picklable, so the
        `conn` attribute is removed and the connection re-opened upon
        unpickling.

        """
        state = self.__dict__.copy()
        if '_conn' in state:
            del state['_conn']
        # NOTE(review): `resume()` is inherited (presumably from
        # TrainingLogBase, not visible here); it appears intended to let the
        # unpickled copy continue under a fresh UUID -- confirm against
        # log.py before relying on this.
        self.resume()
        return state

    def __getitem__(self, time):
        # `_check_time` is inherited (presumably from TrainingLogBase) and
        # validates the requested iteration number; rows are then exposed
        # through a lazy view rather than fetched eagerly.
        self._check_time(time)
        return SQLiteEntry(self, time)

    def __iter__(self):
        # All recorded iteration numbers, including those logged by
        # ancestor logs this one was resumed from, in ascending order.
        return map(itemgetter(0), self.conn.execute(
            ANCESTORS_QUERY + "SELECT DISTINCT time FROM entries "
            "WHERE uuid IN ancestors ORDER BY time ASC", (self.h_uuid,)
        ))

    def __len__(self):
        # Number of distinct iteration numbers across this log and its
        # ancestors.
        return self.conn.execute(
            ANCESTORS_QUERY + "SELECT COUNT(DISTINCT time) FROM entries "
            "WHERE uuid IN ancestors ORDER BY time ASC", (self.h_uuid,)
        ).fetchone()[0]
178
179
180
class SQLiteStatus(MutableMapping):
    """Dictionary-like view of the ``status`` table of one training log.

    Rows are keyed by the owning log's UUID; values pass through the
    module-level adapter/unpickling helpers.
    """
    def __init__(self, log):
        self.log = log

    def __getitem__(self, key):
        cursor = self.log.conn.execute(
            "SELECT value FROM status WHERE uuid = ? AND key = ?",
            (self.log.h_uuid, key)
        )
        return _get_row(cursor.fetchone(), key)

    def __setitem__(self, key, value):
        _register_adapter(value, key)
        connection = self.log.conn
        # The connection context manager commits on success.
        with connection:
            connection.execute(
                "INSERT OR REPLACE INTO status VALUES (?, ?, ?)",
                (self.log.h_uuid, key, value)
            )

    def __delitem__(self, key):
        connection = self.log.conn
        with connection:
            connection.execute(
                "DELETE FROM status WHERE uuid = ? AND key = ?",
                (self.log.h_uuid, key)
            )

    def __len__(self):
        (count,) = self.log.conn.execute(
            "SELECT COUNT(*) FROM status WHERE uuid = ?",
            (self.log.h_uuid,)
        ).fetchone()
        return count

    def __iter__(self):
        rows = self.log.conn.execute(
            "SELECT key FROM status WHERE uuid = ?", (self.log.h_uuid,)
        )
        return map(itemgetter(0), rows)
216
217
218
class SQLiteEntry(MutableMapping):
    """Store log entries in an SQLite database.

    Each entry is a row with the columns `uuid`, `time` (iterations done),
    `key` and `value`. Note that SQLite only supports numeric values,
    strings, and bytes (e.g. the `uuid` column), all other objects will be
    pickled before being stored.

    Entries are automatically retrieved from ancestral logs (i.e. logs that
    were resumed from).

    """
    def __init__(self, log, time):
        self.log = log
        self.time = time

    def __getitem__(self, key):
        query = (
            ANCESTORS_QUERY + "SELECT value FROM entries "
            # The JOIN is relied upon to order rows so that the latest
            # ancestor's value comes first (original author's assumption).
            "JOIN ancestors ON entries.uuid = ancestors.parent "
            "WHERE uuid IN ancestors AND time = ? AND key = ?"
        )
        cursor = self.log.conn.execute(
            query, (self.log.h_uuid, self.time, key))
        return _get_row(cursor.fetchone(), key)

    def __setitem__(self, key, value):
        _register_adapter(value, key)
        connection = self.log.conn
        # The connection context manager commits on success.
        with connection:
            connection.execute(
                "INSERT OR REPLACE INTO entries VALUES (?, ?, ?, ?)",
                (self.log.h_uuid, self.time, key, value)
            )

    def __delitem__(self, key):
        connection = self.log.conn
        with connection:
            connection.execute(
                "DELETE FROM entries WHERE uuid = ? AND time = ? AND key = ?",
                (self.log.h_uuid, self.time, key)
            )

    def __len__(self):
        (count,) = self.log.conn.execute(
            ANCESTORS_QUERY + "SELECT COUNT(*) FROM entries "
            "WHERE uuid IN ancestors AND time = ?",
            (self.log.h_uuid, self.time,)
        ).fetchone()
        return count

    def __iter__(self):
        cursor = self.log.conn.execute(
            ANCESTORS_QUERY + "SELECT key FROM entries "
            "WHERE uuid IN ancestors AND time = ?",
            (self.log.h_uuid, self.time,)
        )
        return map(itemgetter(0), cursor)
272