Completed
Pull Request — master (#1130)
by
unknown
04:55
created

SQLiteLog.__getitem__()   A

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
"""SQLite backend for the main loop log."""
2
import sqlite3
3
import warnings
4
from collections import MutableMapping, Mapping
5
from operator import itemgetter
6
7
import numpy
8
import six
9
from six.moves import cPickle, map
0 ignored issues
show
Bug Best Practice introduced by
This seems to re-define the built-in map.

It is generally discouraged to redefine built-ins as this makes code very hard to read.

Loading history...
10
11
from blocks.config import config
12
from .log import TrainingLogBase
13
14
15
# Common table expression prefix shared by the log queries; callers append
# their own SELECT statement to it. It takes one bound parameter, the UUID
# to start from. Starting at that UUID's 'resumed_from' row in the `status`
# table, it repeatedly follows the stored value to the row of the log that
# was resumed from. The UUIDs of all rows visited this way are exposed as
# the single-column table `ancestors` (column `parent`).
ANCESTORS_QUERY = """
WITH parents (parent, child) AS (
    SELECT uuid, value FROM status
    WHERE key = 'resumed_from' AND uuid = ?
    UNION ALL
    SELECT uuid, value FROM status
    INNER JOIN parents ON status.uuid = parents.child
    WHERE key = 'resumed_from'
),
ancestors AS (SELECT parent FROM parents)
"""
26
27
# Body of the warning issued when a large pickled object is stored. The two
# `{}` fields are filled, in order, with the object's type and the size of
# the blob in bytes. The trailing backslashes are line continuations inside
# the literal, so the sentences form one unbroken line after the two leading
# newlines.
LARGE_BLOB_WARNING = """

A {} object of {} bytes was stored in the SQLite database. SQLite natively \
only supports numbers and text. Other objects will be pickled before being \
saved. For large objects, this can be slow and degrade performance of the \
database."""
33
34
35
def adapt_obj(obj):
    """Binarize objects to be stored in an SQLite database.

    Parameters
    ----------
    obj : object
        Any picklable object.

    Returns
    -------
    blob : memoryview
        A buffer (Python 2) or memoryview (Python 3) of the pickled object
        that can be stored as a BLOB in an SQLite database.

    Notes
    -----
    When the pickled object is larger than `config.max_blob_size` a warning
    is issued, after which a filter is installed that ignores further
    warnings whose message starts with the same prefix.

    """
    blob = sqlite3.Binary(cPickle.dumps(obj))
    if len(blob) > config.max_blob_size:
        warnings.warn('large objects stored in SQLite' +
                      LARGE_BLOB_WARNING.format(type(obj), len(blob)))
        # Prevent the warning with variable message from repeating
        warnings.filterwarnings('ignore', 'large objects .*')
    return blob
57
58
59
def adapt_ndarray(obj):
    """Convert NumPy scalars to floats before storing in SQLite.

    This makes it easier to inspect the database, and speeds things up.

    Parameters
    ----------
    obj : ndarray
        A NumPy array.

    Returns
    -------
    float or memoryview
        If the array was a scalar (zero-dimensional), it returns a floating
        point number. Otherwise it binarizes the NumPy array using
        :func:`adapt_obj`

    """
    # Zero-dimensional arrays hold a single value; store it as a plain number
    # instead of a pickled blob.
    if obj.ndim == 0:
        return float(obj)
    return adapt_obj(obj)
80
81
82
def _get_row(row, key):
    """Handle the returned row e.g. unpickle if needed.

    Parameters
    ----------
    row : tuple or None
        A row as returned by ``fetchone()``; only its first column is used.
    key : str
        The key that was looked up. It is used for the error and to
        recognize the ``'resumed_from'`` channel.

    Returns
    -------
    object
        The first column of the row, unpickled if it is binary data and
        `key` is not ``'resumed_from'``.

    Raises
    ------
    KeyError
        If `row` is ``None``, i.e. the query matched nothing.

    """
    if row is None:
        raise KeyError(key)
    value = row[0]
    # Resumption UUIDs are stored as bytes and should not be unpickled
    if (isinstance(value, (sqlite3.Binary, bytes)) and
            key != 'resumed_from'):
        # NOTE(review): this unpickles whatever binary value the database
        # holds; only open database files from a trusted source.
        value = cPickle.loads(bytes(value))
    return value
92
93
94
def _register_adapter(value, key):
    """Register an adapter if the type of value is unknown.

    Parameters
    ----------
    value : object
        The value about to be written to the database.
    key : str
        The key the value is stored under.

    Notes
    -----
    ``None``, integers, floats, strings, bytes and NumPy arrays are left
    alone. For any other type :func:`adapt_obj` is registered as the
    :mod:`sqlite3` adapter of ``type(value)``, unless the key is
    ``'resumed_from'``.

    """
    # Assuming no storage of non-simple types on channel 'resumed_from'
    if key == 'resumed_from':
        return
    simple_types = (type(None), int, float, six.string_types, bytes,
                    numpy.ndarray)
    if not isinstance(value, simple_types):
        sqlite3.register_adapter(type(value), adapt_obj)
101
102
103
class SQLiteLog(TrainingLogBase, Mapping):
    r"""Training log using SQLite as a backend.

    The log is a read-only mapping from time to :class:`SQLiteEntry`
    objects; the entries themselves are mutable and write to the database.

    Parameters
    ----------
    database : str, optional
        The database (file) to connect to. Can also be `:memory:`. See
        :func:`sqlite3.connect` for details. Uses `config.sqlite_database`
        by default.
    \*\*kwargs
        Arguments to pass to :class:`TrainingLogBase`

    """
    def __init__(self, database=None, **kwargs):
        if database is None:
            database = config.sqlite_database
        self.database = database
        self.status = SQLiteStatus(self)
        self.conn = sqlite3.connect(self.database)
        sqlite3.register_adapter(numpy.ndarray, adapt_ndarray)
        # Both tables are created before the base class is initialized.
        with self.conn:
            self.conn.execute("""CREATE TABLE IF NOT EXISTS entries (
                                   uuid TEXT NOT NULL,
                                   time INT NOT NULL,
                                   "key" TEXT NOT NULL,
                                   value,
                                   PRIMARY KEY(uuid, time, "key")
                                 );""")
            self.conn.execute("""CREATE TABLE IF NOT EXISTS status (
                                   uuid TEXT NOT NULL,
                                   "key" text NOT NULL,
                                   value,
                                   PRIMARY KEY(uuid, "key")
                                 );""")
        super(SQLiteLog, self).__init__(**kwargs)

    @property
    def conn(self):
        """The database connection, opened on first access if missing."""
        if not hasattr(self, '_conn'):
            self._conn = sqlite3.connect(self.database)
        return self._conn

    @conn.setter
    def conn(self, value):
        self._conn = value

    def __getstate__(self):
        """Retrieve the state for pickling.

        :class:`sqlite3.Connection` objects are not picklable, so the
        `conn` attribute is removed and the connection re-opened upon
        unpickling (the `conn` property reconnects when `_conn` is
        missing).

        """
        state = self.__dict__.copy()
        if '_conn' in state:
            del state['_conn']
        # NOTE(review): `resume` is defined outside this class body; it is
        # called after the state was copied, so attribute changes it makes
        # to this object are not part of the returned state.
        self.resume()
        return state

    def __getitem__(self, time):
        """Return the entry for `time` after validating it."""
        self._check_time(time)
        return SQLiteEntry(self, time)

    def __iter__(self):
        """Iterate over the distinct times with entries, in ascending order."""
        return map(itemgetter(0), self.conn.execute(
            ANCESTORS_QUERY + "SELECT DISTINCT time FROM entries "
            "WHERE uuid IN ancestors ORDER BY time ASC", (self.h_uuid,)
        ))

    def __len__(self):
        """Return the number of distinct times that have entries."""
        # The aggregate yields a single row, so no ordering is needed.
        return self.conn.execute(
            ANCESTORS_QUERY + "SELECT COUNT(DISTINCT time) FROM entries "
            "WHERE uuid IN ancestors", (self.h_uuid,)
        ).fetchone()[0]

    def writer(self):
        return self

    def reader(self):
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass
190
191
192
class SQLiteStatus(MutableMapping):
    """Mutable mapping over the rows of the `status` table of one log.

    Only rows whose `uuid` column equals ``log.h_uuid`` are visible. Every
    operation goes through ``log.conn``.

    Parameters
    ----------
    log : object
        The log this status belongs to; must provide the attributes `conn`
        (an SQLite connection) and `h_uuid`.

    """
    def __init__(self, log):
        self.log = log

    def __getitem__(self, key):
        row = self.log.conn.execute(
            "SELECT value FROM status WHERE uuid = ? AND key = ?",
            (self.log.h_uuid, key)
        ).fetchone()
        return _get_row(row, key)

    def __setitem__(self, key, value):
        _register_adapter(value, key)
        with self.log.conn:
            self.log.conn.execute(
                "INSERT OR REPLACE INTO status VALUES (?, ?, ?)",
                (self.log.h_uuid, key, value)
            )

    def __delitem__(self, key):
        """Delete the row for `key`; raise `KeyError` if there is none."""
        with self.log.conn:
            cursor = self.log.conn.execute(
                "DELETE FROM status WHERE uuid = ? AND key = ?",
                (self.log.h_uuid, key)
            )
        # A mapping must signal the deletion of a key it does not hold.
        if cursor.rowcount == 0:
            raise KeyError(key)

    def __len__(self):
        return self.log.conn.execute(
            "SELECT COUNT(*) FROM status WHERE uuid = ?",
            (self.log.h_uuid,)
        ).fetchone()[0]

    def __iter__(self):
        return map(itemgetter(0), self.log.conn.execute(
            "SELECT key FROM status WHERE uuid = ?", (self.log.h_uuid,)
        ))
228
229
230
class SQLiteEntry(MutableMapping):
    """Store log entries in an SQLite database.

    Each entry is a row with the columns `uuid`, `time` (iterations done),
    `key` and `value`. Note that SQLite only supports numeric values,
    strings, and bytes (e.g. the `uuid` column), all other objects will be
    pickled before being stored.

    Entries are automatically retrieved from ancestral logs (i.e. logs that
    were resumed from). Writes and deletions only touch the rows of the
    log's own UUID. A key stored by several logs in the ancestry at the
    same time is reported once by iteration and by `len`.

    Parameters
    ----------
    log : object
        The log this entry belongs to; must provide the attributes `conn`
        (an SQLite connection) and `h_uuid`.
    time : int
        The time (iterations done) this entry refers to.

    """
    def __init__(self, log, time):
        self.log = log
        self.time = time

    def __getitem__(self, key):
        row = self.log.conn.execute(
            ANCESTORS_QUERY + "SELECT value FROM entries "
            # JOIN statement should sort things so that the latest is returned
            "JOIN ancestors ON entries.uuid = ancestors.parent "
            "WHERE uuid IN ancestors AND time = ? AND key = ?",
            (self.log.h_uuid, self.time, key)
        ).fetchone()
        return _get_row(row, key)

    def __setitem__(self, key, value):
        _register_adapter(value, key)
        with self.log.conn:
            self.log.conn.execute(
                "INSERT OR REPLACE INTO entries VALUES (?, ?, ?, ?)",
                (self.log.h_uuid, self.time, key, value)
            )

    def __delitem__(self, key):
        with self.log.conn:
            self.log.conn.execute(
                "DELETE FROM entries WHERE uuid = ? AND time = ? AND key = ?",
                (self.log.h_uuid, self.time, key)
            )

    def __len__(self):
        # DISTINCT: the same key may be stored by more than one log in the
        # ancestry, but a mapping holds each key only once.
        return self.log.conn.execute(
            ANCESTORS_QUERY + "SELECT COUNT(DISTINCT key) FROM entries "
            "WHERE uuid IN ancestors AND time = ?",
            (self.log.h_uuid, self.time,)
        ).fetchone()[0]

    def __iter__(self):
        # DISTINCT for the same reason as in __len__, keeping both in sync.
        return map(itemgetter(0), self.log.conn.execute(
            ANCESTORS_QUERY + "SELECT DISTINCT key FROM entries "
            "WHERE uuid IN ancestors AND time = ?",
            (self.log.h_uuid, self.time,)
        ))
284