1 | """SQLite backend for the main loop log.""" |
||
import sqlite3
import warnings
from operator import itemgetter

# `MutableMapping`/`Mapping` moved to `collections.abc` in Python 3.3 and
# were removed from `collections` in Python 3.10.
try:
    from collections.abc import MutableMapping, Mapping
except ImportError:  # Python 2
    from collections import MutableMapping, Mapping

import numpy
import six
from six.moves import cPickle, map

from blocks.config import config

from .log import TrainingLogBase
13 | |||
14 | |||
# Recursive common table expression that follows the chain of
# 'resumed_from' status rows, collecting the UUIDs of every ancestral log
# (the logs this one was resumed from, directly or transitively).
# Prepend it to a query to make the `ancestors` table available; it takes
# the current log's UUID as its single `?` parameter.
ANCESTORS_QUERY = """
WITH parents (parent, child) AS (
    SELECT uuid, value FROM status
    WHERE key = 'resumed_from' AND uuid = ?
    UNION ALL
    SELECT uuid, value FROM status
    INNER JOIN parents ON status.uuid = parents.child
    WHERE key = 'resumed_from'
),
ancestors AS (SELECT parent FROM parents)
"""

# Warning text used by `adapt_obj` when a pickled value exceeds
# `config.max_blob_size`; formatted with the object's type and byte size.
# The trailing backslashes splice the lines into a single paragraph.
LARGE_BLOB_WARNING = """

A {} object of {} bytes was stored in the SQLite database. SQLite natively \
only supports numbers and text. Other objects will be pickled before being \
saved. For large objects, this can be slow and degrade performance of the \
database."""
33 | |||
34 | |||
def adapt_obj(obj):
    """Binarize objects to be stored in an SQLite database.

    Parameters
    ----------
    obj : object
        Any picklable object.

    Returns
    -------
    blob : memoryview
        A buffer (Python 2) or memoryview (Python 3) of the pickled object
        that can be stored as a BLOB in an SQLite database.

    """
    serialized = sqlite3.Binary(cPickle.dumps(obj))
    size = len(serialized)
    if size > config.max_blob_size:
        warnings.warn('large objects stored in SQLite' +
                      LARGE_BLOB_WARNING.format(type(obj), size))
        # Silence further occurrences; the message varies per object, so
        # the default "warn once per unique message" would keep firing.
        warnings.filterwarnings('ignore', 'large objects .*')
    return serialized
||
57 | |||
58 | |||
def adapt_ndarray(obj):
    """Convert NumPy scalars to floats before storing in SQLite.

    Storing zero-dimensional arrays as plain floats keeps the database
    human-inspectable and avoids a pickling round-trip.

    Parameters
    ----------
    obj : ndarray
        A NumPy array.

    Returns
    -------
    float or memoryview
        A floating point number for scalar (0-d) arrays; otherwise the
        array pickled into a BLOB by :func:`adapt_obj`.

    """
    if obj.ndim == 0:
        return float(obj)
    return adapt_obj(obj)
||
80 | |||
81 | |||
82 | def _get_row(row, key): |
||
83 | """Handle the returned row e.g. unpickle if needed.""" |
||
84 | if row is not None: |
||
85 | value = row[0] |
||
86 | # Resumption UUIDs are stored as bytes and should not be unpickled |
||
87 | if (isinstance(value, (sqlite3.Binary, bytes)) and |
||
88 | key != 'resumed_from'): |
||
89 | value = cPickle.loads(bytes(value)) |
||
90 | return value |
||
91 | raise KeyError(key) |
||
92 | |||
93 | |||
def _register_adapter(value, key):
    """Register a pickling adapter when `value` has a non-native type."""
    # The 'resumed_from' channel only ever holds simple types; skip it.
    if key == 'resumed_from':
        return
    native_types = (type(None), int, float, six.string_types,
                    bytes, numpy.ndarray)
    if not isinstance(value, native_types):
        sqlite3.register_adapter(type(value), adapt_obj)
||
101 | |||
102 | |||
class SQLiteLog(TrainingLogBase, Mapping):
    r"""Training log using SQLite as a backend.

    Parameters
    ----------
    database : str, optional
        The database (file) to connect to. Can also be `:memory:`. See
        :func:`sqlite3.connect` for details. Uses `config.sqlite_database`
        by default.
    \*\*kwargs
        Arguments to pass to :class:`TrainingLogBase`

    """
    def __init__(self, database=None, **kwargs):
        if database is None:
            database = config.sqlite_database
        self.database = database
        self.conn = sqlite3.connect(database)
        # Store NumPy arrays through adapt_ndarray (scalars become plain
        # floats; everything else is pickled into a BLOB).
        sqlite3.register_adapter(numpy.ndarray, adapt_ndarray)
        with self.conn:
            # The `value` columns are intentionally typeless: SQLite's
            # dynamic typing lets them hold numbers, text and BLOBs alike.
            self.conn.execute("""CREATE TABLE IF NOT EXISTS entries (
                uuid TEXT NOT NULL,
                time INT NOT NULL,
                "key" TEXT NOT NULL,
                value,
                PRIMARY KEY(uuid, time, "key")
            );""")
            self.conn.execute("""CREATE TABLE IF NOT EXISTS status (
                uuid TEXT NOT NULL,
                "key" text NOT NULL,
                value,
                PRIMARY KEY(uuid, "key")
            );""")
        self.status = SQLiteStatus(self)
        super(SQLiteLog, self).__init__(**kwargs)

    @property
    def conn(self):
        # Opened lazily so the connection can be re-established after
        # unpickling (see __getstate__, which drops `_conn`).
        if not hasattr(self, '_conn'):
            self._conn = sqlite3.connect(self.database)
        return self._conn

    @conn.setter
    def conn(self, value):
        self._conn = value

    def __getstate__(self):
        """Retrieve the state for pickling.

        :class:`sqlite3.Connection` objects are not picklable, so the
        `conn` attribute is removed and the connection re-opened upon
        unpickling.

        """
        state = self.__dict__.copy()
        if '_conn' in state:
            del state['_conn']
            # NOTE(review): `resume()` (from TrainingLogBase) is invoked
            # whenever the log is pickled with an open connection —
            # presumably to re-key the log under a new UUID so the pickle
            # records a 'resumed_from' ancestor; confirm against
            # TrainingLogBase.resume().
            self.resume()
        return state

    def __getitem__(self, time):
        # Entries are views; no row is created until something is written.
        self._check_time(time)
        return SQLiteEntry(self, time)

    def __iter__(self):
        # Yields all distinct times, including those written by ancestral
        # (resumed-from) logs, in ascending order.
        return map(itemgetter(0), self.conn.execute(
            ANCESTORS_QUERY + "SELECT DISTINCT time FROM entries "
            "WHERE uuid IN ancestors ORDER BY time ASC", (self.h_uuid,)
        ))

    def __len__(self):
        # Counts distinct times across this log and its ancestors.
        return self.conn.execute(
            ANCESTORS_QUERY + "SELECT COUNT(DISTINCT time) FROM entries "
            "WHERE uuid IN ancestors ORDER BY time ASC", (self.h_uuid,)
        ).fetchone()[0]
||
178 | |||
179 | |||
class SQLiteStatus(MutableMapping):
    """Status dictionary of an :class:`SQLiteLog`.

    A thin mapping view over the `status` table, scoped to the owning
    log's UUID. Values pass through :func:`_get_row` on the way out and
    :func:`_register_adapter` on the way in.
    """

    def __init__(self, log):
        self.log = log

    def __getitem__(self, key):
        fetched = self.log.conn.execute(
            "SELECT value FROM status WHERE uuid = ? AND key = ?",
            (self.log.h_uuid, key)).fetchone()
        return _get_row(fetched, key)

    def __setitem__(self, key, value):
        _register_adapter(value, key)
        # The connection context manager commits on success.
        with self.log.conn as connection:
            connection.execute(
                "INSERT OR REPLACE INTO status VALUES (?, ?, ?)",
                (self.log.h_uuid, key, value))

    def __delitem__(self, key):
        with self.log.conn as connection:
            connection.execute(
                "DELETE FROM status WHERE uuid = ? AND key = ?",
                (self.log.h_uuid, key))

    def __len__(self):
        count, = self.log.conn.execute(
            "SELECT COUNT(*) FROM status WHERE uuid = ?",
            (self.log.h_uuid,)).fetchone()
        return count

    def __iter__(self):
        cursor = self.log.conn.execute(
            "SELECT key FROM status WHERE uuid = ?", (self.log.h_uuid,))
        return (record[0] for record in cursor)
||
216 | |||
217 | |||
class SQLiteEntry(MutableMapping):
    """A single log entry backed by the `entries` table of an SQLite log.

    Every item is stored as a row with the columns `uuid`, `time`
    (iterations done), `key` and `value`. SQLite natively holds numbers,
    strings and bytes (e.g. the `uuid` column); anything else is pickled
    before storage. Lookups transparently include rows written by
    ancestral logs (logs this one was resumed from).
    """

    def __init__(self, log, time):
        self.log = log
        self.time = time

    def __getitem__(self, key):
        # The JOIN orders rows so that the latest ancestor's value is
        # returned first.
        query = (ANCESTORS_QUERY +
                 "SELECT value FROM entries "
                 "JOIN ancestors ON entries.uuid = ancestors.parent "
                 "WHERE uuid IN ancestors AND time = ? AND key = ?")
        fetched = self.log.conn.execute(
            query, (self.log.h_uuid, self.time, key)).fetchone()
        return _get_row(fetched, key)

    def __setitem__(self, key, value):
        _register_adapter(value, key)
        # The connection context manager commits on success.
        with self.log.conn as connection:
            connection.execute(
                "INSERT OR REPLACE INTO entries VALUES (?, ?, ?, ?)",
                (self.log.h_uuid, self.time, key, value))

    def __delitem__(self, key):
        with self.log.conn as connection:
            connection.execute(
                "DELETE FROM entries WHERE uuid = ? AND time = ? AND key = ?",
                (self.log.h_uuid, self.time, key))

    def __len__(self):
        query = (ANCESTORS_QUERY +
                 "SELECT COUNT(*) FROM entries "
                 "WHERE uuid IN ancestors AND time = ?")
        count, = self.log.conn.execute(
            query, (self.log.h_uuid, self.time,)).fetchone()
        return count

    def __iter__(self):
        query = (ANCESTORS_QUERY +
                 "SELECT key FROM entries "
                 "WHERE uuid IN ancestors AND time = ?")
        cursor = self.log.conn.execute(query, (self.log.h_uuid, self.time,))
        return (record[0] for record in cursor)
||