1 | #!/usr/bin/env python |
||
2 | # coding=utf-8 |
||
3 | from __future__ import division, print_function, unicode_literals |
||
import hashlib
import io
import json
import os
from threading import Lock

import sacred.optional as opt
from sacred.commandline_options import CommandLineOption
from sacred.dependencies import get_digest
from sacred.observers.base import RunObserver
from sacred.serializer import flatten, restore
||
14 | |||
15 | DEFAULT_SQL_PRIORITY = 40 |
||
16 | |||
17 | |||
18 | # ############################# Observer #################################### # |
||
19 | |||
class SqlObserver(RunObserver):
    """Observer that stores all information about runs of an experiment
    in a SQL database using the SQLAlchemy ORM models defined below."""

    @classmethod
    def create(cls, url, echo=False, priority=DEFAULT_SQL_PRIORITY):
        """Create a SqlObserver connected to the database at ``url``.

        :param url: database URL of the form
                    ``dialect://username:password@host:port/database``
        :param echo: if True, let SQLAlchemy log all emitted SQL statements
        :param priority: priority used by sacred to order observers
        """
        engine = sa.create_engine(url, echo=echo)
        session_factory = sessionmaker(bind=engine)
        # make session thread-local to avoid problems with sqlite (see #275)
        session = scoped_session(session_factory)
        return cls(engine, session, priority)

    def __init__(self, engine, session, priority=DEFAULT_SQL_PRIORITY):
        self.engine = engine
        self.session = session
        self.priority = priority
        self.run = None  # ORM Run row of the currently observed run
        # serializes commits in save(): sacred may fire events from the
        # heartbeat thread concurrently with the main thread
        self.lock = Lock()

    def _add_run(self, ex_info, command, host_info, config, meta_info,
                 _id, status, start_time=None, queue_time=None):
        """Create, persist and remember a new Run row; return its id.

        Shared implementation of ``started_event`` and ``queued_event``,
        which previously duplicated this logic line for line.
        """
        Base.metadata.create_all(self.engine)
        sql_exp = Experiment.get_or_create(ex_info, self.session)
        sql_host = Host.get_or_create(host_info, self.session)
        if _id is None:
            # auto-generate an id from the highest existing surrogate key
            last = self.session.query(Run).order_by(Run.id.desc()).first()
            _id = 0 if last is None else last.id + 1

        self.run = Run(run_id=str(_id),
                       start_time=start_time,
                       queue_time=queue_time,
                       config=json.dumps(flatten(config)),
                       command=command,
                       priority=meta_info.get('priority', 0),
                       comment=meta_info.get('comment', ''),
                       experiment=sql_exp,
                       host=sql_host,
                       status=status)
        self.session.add(self.run)
        self.save()
        # NOTE: when _id is 0 this falls through to the string run_id;
        # kept as-is for backward compatibility with the original events
        return _id or self.run.run_id

    def started_event(self, ex_info, command, host_info, start_time, config,
                      meta_info, _id):
        """Record the start of a run; return the id assigned to it."""
        return self._add_run(ex_info, command, host_info, config, meta_info,
                             _id, status='RUNNING', start_time=start_time)

    def queued_event(self, ex_info, command, host_info, queue_time, config,
                     meta_info, _id):
        """Record a queued run; return the id assigned to it.

        Fix: ``queue_time`` was previously received but silently discarded;
        it is now stored in the ``Run.queue_time`` column.
        """
        return self._add_run(ex_info, command, host_info, config, meta_info,
                             _id, status='QUEUED', queue_time=queue_time)

    def heartbeat_event(self, info, captured_out, beat_time, result):
        """Update info, captured output, heartbeat time and interim result."""
        self.run.info = json.dumps(flatten(info))
        self.run.captured_out = captured_out
        self.run.heartbeat = beat_time
        self.run.result = result
        self.save()

    def completed_event(self, stop_time, result):
        """Mark the run as COMPLETED and store its final result."""
        self.run.stop_time = stop_time
        self.run.result = result
        self.run.status = 'COMPLETED'
        self.save()

    def interrupted_event(self, interrupt_time, status):
        """Mark the run as interrupted with the given status."""
        self.run.stop_time = interrupt_time
        self.run.status = status
        self.save()

    def failed_event(self, fail_time, fail_trace):
        """Mark the run as FAILED and store the joined traceback lines."""
        self.run.stop_time = fail_time
        self.run.fail_trace = '\n'.join(fail_trace)
        self.run.status = 'FAILED'
        self.save()

    def resource_event(self, filename):
        """Attach a (content-deduplicated) resource file to the run."""
        res = Resource.get_or_create(filename, self.session)
        self.run.resources.append(res)
        self.save()

    def artifact_event(self, name, filename):
        """Attach an artifact file (stored under ``name``) to the run."""
        a = Artifact.create(name, filename)
        self.run.artifacts.append(a)
        self.save()

    def save(self):
        """Commit the session; locked to serialize concurrent commits."""
        with self.lock:
            self.session.commit()

    def query(self, _id):
        """Return the JSON representation of the run with primary key _id.

        NOTE(review): this filters on the integer surrogate key ``Run.id``,
        not the external ``run_id`` string -- confirm which one callers pass.
        """
        run = self.session.query(Run).filter_by(id=_id).first()
        return run.to_json()

    def __eq__(self, other):
        if isinstance(other, SqlObserver):
            # fixme: this will probably fail to detect two equivalent engines
            return (self.engine == other.engine and
                    self.session == other.session)
        return False

    def __ne__(self, other):
        return not self.__eq__(other)
||
131 | |||
132 | |||
133 | # ######################## Commandline Option ############################### # |
||
134 | |||
class SqlOption(CommandLineOption):
    """Add a SQL Observer to the experiment."""

    # only offered when sqlalchemy is importable
    __depends_on__ = 'sqlalchemy'

    arg = 'DB_URL'
    arg_description = (
        "The typical form is: dialect://username:password@host:port/database")

    @classmethod
    def apply(cls, args, run):
        """Attach a SqlObserver built from the DB_URL argument to the run."""
        observer = SqlObserver.create(args)
        run.observers.append(observer)
||
146 | |||
147 | |||
148 | # ################################ ORM ###################################### # |
||
# ORM models are only defined when sqlalchemy is actually installed,
# so importing this module never hard-requires the dependency.
if opt.has_sqlalchemy:  # noqa
    import sqlalchemy as sa
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker, scoped_session

    # shared declarative base for all ORM models defined below
    Base = declarative_base()
||
155 | |||
156 | class Source(Base): |
||
157 | __tablename__ = 'source' |
||
158 | |||
159 | @classmethod |
||
160 | def get_or_create(cls, filename, md5sum, session, basedir): |
||
161 | instance = session.query(cls).filter_by(filename=filename, |
||
162 | md5sum=md5sum).first() |
||
163 | if instance: |
||
164 | return instance |
||
165 | full_path = os.path.join(basedir, filename) |
||
166 | md5sum_ = get_digest(full_path) |
||
167 | assert md5sum_ == md5sum, 'found md5 mismatch for {}: {} != {}'\ |
||
168 | .format(filename, md5sum, md5sum_) |
||
169 | with open(full_path, 'r') as f: |
||
170 | return cls(filename=filename, md5sum=md5sum, content=f.read()) |
||
171 | |||
172 | source_id = sa.Column(sa.Integer, primary_key=True) |
||
173 | filename = sa.Column(sa.String(256)) |
||
174 | md5sum = sa.Column(sa.String(32)) |
||
175 | content = sa.Column(sa.Text) |
||
176 | |||
177 | def to_json(self): |
||
178 | return {'filename': self.filename, |
||
179 | 'md5sum': self.md5sum} |
||
180 | |||
181 | class Dependency(Base): |
||
182 | __tablename__ = 'dependency' |
||
183 | |||
184 | @classmethod |
||
185 | def get_or_create(cls, dep, session): |
||
186 | name, _, version = dep.partition('==') |
||
187 | instance = session.query(cls).filter_by(name=name, |
||
188 | version=version).first() |
||
189 | View Code Duplication | if instance: |
|
0 ignored issues
–
show
|
|||
190 | return instance |
||
191 | return cls(name=name, version=version) |
||
192 | |||
193 | dependency_id = sa.Column(sa.Integer, primary_key=True) |
||
194 | name = sa.Column(sa.String(32)) |
||
195 | version = sa.Column(sa.String(16)) |
||
196 | |||
197 | def to_json(self): |
||
198 | return "{}=={}".format(self.name, self.version) |
||
199 | |||
200 | class Artifact(Base): |
||
201 | __tablename__ = 'artifact' |
||
202 | |||
203 | @classmethod |
||
204 | def create(cls, name, filename): |
||
205 | with open(filename, 'rb') as f: |
||
206 | return cls(filename=name, content=f.read()) |
||
207 | |||
208 | View Code Duplication | artifact_id = sa.Column(sa.Integer, primary_key=True) |
|
0 ignored issues
–
show
|
|||
209 | filename = sa.Column(sa.String(64)) |
||
210 | content = sa.Column(sa.LargeBinary) |
||
211 | |||
212 | run_id = sa.Column(sa.String(24), sa.ForeignKey('run.run_id')) |
||
213 | run = sa.orm.relationship("Run", backref=sa.orm.backref('artifacts')) |
||
214 | |||
215 | def to_json(self): |
||
216 | return {'_id': self.artifact_id, |
||
217 | 'filename': self.filename} |
||
218 | |||
219 | class Resource(Base): |
||
220 | __tablename__ = 'resource' |
||
221 | |||
222 | @classmethod |
||
223 | def get_or_create(cls, filename, session): |
||
224 | md5sum = get_digest(filename) |
||
225 | instance = session.query(cls).filter_by(filename=filename, |
||
226 | md5sum=md5sum).first() |
||
227 | if instance: |
||
228 | return instance |
||
229 | with open(filename, 'rb') as f: |
||
230 | return cls(filename=filename, md5sum=md5sum, content=f.read()) |
||
231 | |||
232 | resource_id = sa.Column(sa.Integer, primary_key=True) |
||
233 | filename = sa.Column(sa.String(256)) |
||
234 | md5sum = sa.Column(sa.String(32)) |
||
235 | content = sa.Column(sa.LargeBinary) |
||
236 | |||
237 | def to_json(self): |
||
238 | return {'filename': self.filename, |
||
239 | 'md5sum': self.md5sum} |
||
240 | |||
241 | class Host(Base): |
||
242 | __tablename__ = 'host' |
||
243 | |||
244 | @classmethod |
||
245 | def get_or_create(cls, host_info, session): |
||
246 | h = dict( |
||
247 | hostname=host_info['hostname'], |
||
248 | cpu=host_info['cpu'], |
||
249 | os=host_info['os'][0], |
||
250 | os_info=host_info['os'][1], |
||
251 | python_version=host_info['python_version'] |
||
252 | ) |
||
253 | |||
254 | return session.query(cls).filter_by(**h).first() or cls(**h) |
||
255 | |||
256 | host_id = sa.Column(sa.Integer, primary_key=True) |
||
257 | cpu = sa.Column(sa.String(64)) |
||
258 | hostname = sa.Column(sa.String(64)) |
||
259 | os = sa.Column(sa.String(16)) |
||
260 | os_info = sa.Column(sa.String(64)) |
||
261 | python_version = sa.Column(sa.String(16)) |
||
262 | |||
263 | def to_json(self): |
||
264 | return {'cpu': self.cpu, |
||
265 | 'hostname': self.hostname, |
||
266 | 'os': [self.os, self.os_info], |
||
267 | 'python_version': self.python_version} |
||
268 | |||
    # many-to-many: which source files belong to which experiments
    experiment_source_association = sa.Table(
        'experiments_sources', Base.metadata,
        sa.Column('experiment_id', sa.Integer,
                  sa.ForeignKey('experiment.experiment_id')),
        sa.Column('source_id', sa.Integer, sa.ForeignKey('source.source_id'))
    )

    # many-to-many: which package dependencies belong to which experiments
    experiment_dependency_association = sa.Table(
        'experiments_dependencies', Base.metadata,
        sa.Column('experiment_id', sa.Integer,
                  sa.ForeignKey('experiment.experiment_id')),
        sa.Column('dependency_id', sa.Integer,
                  sa.ForeignKey('dependency.dependency_id'))
    )
||
283 | |||
    class Experiment(Base):
        """ORM model of an experiment: name, source files and dependencies."""

        __tablename__ = 'experiment'

        @classmethod
        def get_or_create(cls, ex_info, session):
            """Return the stored Experiment matching ``ex_info`` or a new row.

            Uniqueness is decided by (name, md5 of the serialized ex_info).
            """
            name = ex_info['name']
            # Compute a MD5sum of the ex_info to determine its uniqueness
            h = hashlib.md5()
            # NOTE(review): json.dumps without sort_keys makes the digest
            # depend on dict ordering -- presumably stable within one
            # Python version, but confirm before relying on it across runs
            h.update(json.dumps(ex_info).encode())
            md5 = h.hexdigest()
            instance = session.query(cls).filter_by(name=name,
                                                    md5sum=md5).first()
            if instance:
                return instance

            # new experiment: resolve (or create) its dependency and
            # source rows before building the Experiment itself
            dependencies = [Dependency.get_or_create(d, session)
                            for d in ex_info['dependencies']]
            sources = [Source.get_or_create(s, md5sum, session,
                                            ex_info['base_dir'])
                       for s, md5sum in ex_info['sources']]

            return cls(name=name, dependencies=dependencies, sources=sources,
                       md5sum=md5, base_dir=ex_info['base_dir'])

        experiment_id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.String(32))
        md5sum = sa.Column(sa.String(32))
        base_dir = sa.Column(sa.String(64))
        # many-to-many links to source files and package dependencies
        sources = sa.orm.relationship("Source",
                                      secondary=experiment_source_association,
                                      backref="experiments")
        dependencies = sa.orm.relationship(
            "Dependency",
            secondary=experiment_dependency_association,
            backref="experiments")

        def to_json(self):
            """Return a JSON-serializable representation of the experiment."""
            return {'name': self.name,
                    'base_dir': self.base_dir,
                    'sources': [s.to_json() for s in self.sources],
                    'dependencies': [d.to_json() for d in self.dependencies]}
||
325 | |||
    # many-to-many: which resource files were used by which runs
    run_resource_association = sa.Table(
        'runs_resources', Base.metadata,
        sa.Column('run_id', sa.String(24), sa.ForeignKey('run.run_id')),
        sa.Column('resource_id', sa.Integer,
                  sa.ForeignKey('resource.resource_id'))
    )
||
332 | |||
    class Run(Base):
        """ORM model of a single run of an experiment."""

        __tablename__ = 'run'
        # surrogate primary key; also used by SqlObserver to auto-generate
        # the next run id
        id = sa.Column(sa.Integer, primary_key=True)

        # external run identifier (stored stringified), referenced by the
        # artifact and resource tables
        run_id = sa.Column(sa.String(24), unique=True)

        command = sa.Column(sa.String(64))

        # times
        start_time = sa.Column(sa.DateTime)
        heartbeat = sa.Column(sa.DateTime)
        stop_time = sa.Column(sa.DateTime)
        queue_time = sa.Column(sa.DateTime)

        # meta info
        priority = sa.Column(sa.Float)
        comment = sa.Column(sa.Text)

        fail_trace = sa.Column(sa.Text)

        # Captured out
        # TODO: move to separate table?
        captured_out = sa.Column(sa.Text)

        # Configuration & info
        # stored as flattened JSON text, restored in to_json()
        # TODO: switch type to json if possible
        config = sa.Column(sa.Text)
        info = sa.Column(sa.Text)

        status = sa.Column(sa.Enum("RUNNING", "COMPLETED", "INTERRUPTED",
                                   "TIMEOUT", "FAILED", name="status_enum"))

        host_id = sa.Column(sa.Integer, sa.ForeignKey('host.host_id'))
        host = sa.orm.relationship("Host", backref=sa.orm.backref('runs'))

        experiment_id = sa.Column(sa.Integer,
                                  sa.ForeignKey('experiment.experiment_id'))
        experiment = sa.orm.relationship("Experiment",
                                         backref=sa.orm.backref('runs'))

        # artifacts = backref
        resources = sa.orm.relationship("Resource",
                                        secondary=run_resource_association,
                                        backref="runs")

        result = sa.Column(sa.Float)

        def to_json(self):
            """Return a JSON-serializable dict mirroring sacred's run format.

            Requires related host/experiment rows to be present; config is
            deserialized back through sacred's restore().
            """
            return {
                '_id': self.run_id,
                'command': self.command,
                'start_time': self.start_time,
                'heartbeat': self.heartbeat,
                'stop_time': self.stop_time,
                'queue_time': self.queue_time,
                'status': self.status,
                'result': self.result,
                'meta': {
                    'comment': self.comment,
                    'priority': self.priority},
                'resources': [r.to_json() for r in self.resources],
                'artifacts': [a.to_json() for a in self.artifacts],
                'host': self.host.to_json(),
                'experiment': self.experiment.to_json(),
                'config': restore(json.loads(self.config)),
                'captured_out': self.captured_out,
                'fail_trace': self.fail_trace,
            }
||
401 |