from .contextlibbackport import ContextDecorator
from .internal import ContextMethodDecorator
import sacred.optional as opt

if opt.has_tensorflow:
    import tensorflow
else:
    tensorflow = None


class LogFileWriter(ContextDecorator, ContextMethodDecorator):
    """
    Intercept ``logdir`` each time a new ``FileWriter`` instance is created.

    :param experiment: The Sacred experiment whose run should record the
        log directories.

    The experiment must be in the running state when the annotated
    function / the context manager is entered.

    When creating ``FileWriters`` in Tensorflow, you might want to
    store the paths of the produced log files in the Sacred database.

    In the scope of ``LogFileWriter``, the corresponding log directory path
    is appended to a list in ``experiment.info["tensorflow"]["logdirs"]``.

    ``LogFileWriter`` can be used both as a context manager and as
    a decorator on a function.

    Example usage as a decorator::

        ex = Experiment("my experiment")
        @LogFileWriter(ex)
        def run_experiment(_run):
            with tf.Session() as s:
                swr = tf.summary.FileWriter("/tmp/1", s.graph)
                # _run.info["tensorflow"]["logdirs"] == ["/tmp/1"]
                swr2 = tf.summary.FileWriter("./test", s.graph)
                # _run.info["tensorflow"]["logdirs"] == ["/tmp/1", "./test"]

    Example usage as a context manager::

        ex = Experiment("my experiment")
        def run_experiment(_run):
            with tf.Session() as s:
                with LogFileWriter(ex):
                    swr = tf.summary.FileWriter("/tmp/1", s.graph)
                    # _run.info["tensorflow"]["logdirs"] == ["/tmp/1"]
                    swr2 = tf.summary.FileWriter("./test", s.graph)
                    # _run.info["tensorflow"]["logdirs"] == ["/tmp/1", "./test"]
                # This call happens outside the scope and won't be captured:
                swr3 = tf.summary.FileWriter("./nothing", s.graph)
                # Nothing has changed:
                # _run.info["tensorflow"]["logdirs"] == ["/tmp/1", "./test"]

    """

    def __init__(self, experiment):
        self.experiment = experiment

        def log_writer_decorator(instance, original_method, original_args,
                                 original_kwargs):
            # Run the original ``FileWriter.__init__`` first, then record
            # the log directory it was given.
            result = original_method(instance, *original_args,
                                     **original_kwargs)
            # ``logdir`` is the first positional argument of
            # ``FileWriter.__init__``, but it may also be passed by keyword.
            if "logdir" in original_kwargs:
                logdir = original_kwargs["logdir"]
            else:
                logdir = original_args[0]
            self.experiment.info.setdefault("tensorflow", {}).setdefault(
                "logdirs", []).append(logdir)
            return result

        # Patch ``FileWriter.__init__`` for as long as the context manager /
        # decorated function is active, so every new ``FileWriter`` passes
        # through ``log_writer_decorator`` above.
        ContextMethodDecorator.__init__(self,
                                        tensorflow.summary.FileWriter,
                                        "__init__",
                                        log_writer_decorator)
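

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the library API: one way ``LogFileWriter``
# might be wired into a Sacred run, mirroring the docstring examples above.
# The experiment name and log directory below are made up for illustration,
# and the TensorFlow 1.x API (``tensorflow.Session`` /
# ``tensorflow.summary.FileWriter``) is assumed to be available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from sacred import Experiment

    demo_ex = Experiment("logfilewriter_demo")  # hypothetical experiment name

    @demo_ex.main
    @LogFileWriter(demo_ex)
    def demo_run(_run):
        with tensorflow.Session() as session:
            tensorflow.summary.FileWriter("/tmp/logfilewriter_demo",
                                          session.graph)
            # The intercepted path is now recorded in the run info.
            print(_run.info["tensorflow"]["logdirs"])

    demo_ex.run()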