Completed
Pull Request — master (#184)
by Martin
45s
created

test_log_summary_writer_class()   F

Complexity

Conditions 9

Size

Total Lines 32

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 9
c 2
b 0
f 0
dl 0
loc 32
rs 3

2 Methods

Rating   Name   Duplication   Size   Complexity  
A FooClass.__init__() 0 2 1
A FooClass.hello() 0 4 2
1
# -*- coding: utf8 -*-
2
import pytest
3
4
from sacred import Experiment
5
from sacred.stflow import LogFileWriter
6
7
8
@pytest.fixture
def ex():
    """Provide a new sacred ``Experiment`` named 'tensorflow_tests'."""
    experiment = Experiment('tensorflow_tests')
    return experiment
11
12
13
# Creates a simplified tensorflow interface if necessary
14
# so tensorflow is not required during the tests
15
@pytest.fixture()
def tf():
    """Return tensorflow, or a minimal mock of it when it is not available.

    If ``sacred.optional.has_tensorflow`` is true, the ``tensorflow`` object
    exported by ``sacred.optional`` is returned unchanged. Otherwise a small
    stand-in exposing ``summary.FileWriter`` and ``Session`` is created,
    assigned to the ``tensorflow`` attribute of
    ``sacred.stflow.method_interception`` and returned.
    """
    from sacred.optional import has_tensorflow, tensorflow

    if has_tensorflow:
        return tensorflow

    # Let's define a mocked tensorflow
    class tensorflow:
        class summary:
            class FileWriter:
                def __init__(self, logdir, graph):
                    self.logdir = logdir
                    self.graph = graph
                    print("Mocked FileWriter got logdir=%s, graph=%s" % (logdir, graph))

        class Session:
            def __init__(self):
                self.graph = None

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                # Returning None means exceptions are never suppressed.
                return None

    # Make stflow use the mock for the duration of the tests
    import sacred.stflow.method_interception as interception
    interception.tensorflow = tensorflow
    return tensorflow
44
45
46
# Tests whether logdir is stored into the info dictionary when creating a new FileWriter object
47 View Code Duplication
def test_log_file_writer(ex, tf):
    """Check that FileWriter log directories are recorded in ``_run.info``.

    The experiment's main function is decorated with ``LogFileWriter(ex)``
    and additionally opens a ``LogFileWriter(ex)`` context; both FileWriter
    log directories are expected under ``info["tensorflow"]["logdirs"]``.
    """
    first_logdir = "/dev/null"
    second_logdir = "/tmp/sacred_test"

    @ex.main
    @LogFileWriter(ex)
    def run_experiment(_run):
        # Nothing has been recorded before any FileWriter is created.
        assert _run.info.get("tensorflow") is None
        with tf.Session() as session:
            with LogFileWriter(ex):
                writer = tf.summary.FileWriter(logdir=first_logdir, graph=session.graph)
            assert writer is not None
            assert _run.info["tensorflow"]["logdirs"] == [first_logdir]
            tf.summary.FileWriter(second_logdir, session.graph)
            assert _run.info["tensorflow"]["logdirs"] == [first_logdir, second_logdir]

    ex.run()
64
65
66
def test_log_summary_writer_as_context_manager(ex, tf):
    """Check that log directories are captured only inside ``LogFileWriter``.

    FileWriter objects created before or after the ``with LogFileWriter(ex)``
    block must leave ``_run.info`` untouched; those created inside it must be
    listed under ``info["tensorflow"]["logdirs"]``.
    """
    first_logdir = "/dev/null"
    second_logdir = "/tmp/sacred_test"

    @ex.main
    def run_experiment(_run):
        assert _run.info.get("tensorflow") is None
        with tf.Session() as session:
            # Outside of the context manager nothing is expected to be recorded
            writer = tf.summary.FileWriter(logdir=first_logdir, graph=session.graph)
            assert writer is not None
            assert _run.info.get("tensorflow") is None

            # Inside the context manager every log directory is expected to be recorded
            with LogFileWriter(ex):
                writer = tf.summary.FileWriter(logdir=first_logdir, graph=session.graph)
                assert writer is not None
                assert _run.info["tensorflow"]["logdirs"] == [first_logdir]
                tf.summary.FileWriter(second_logdir, session.graph)
                assert _run.info["tensorflow"]["logdirs"] == [first_logdir, second_logdir]

            # Created after leaving the context manager: must not be recorded
            tf.summary.FileWriter("/tmp/whatever", session.graph)
            assert _run.info["tensorflow"]["logdirs"] == [first_logdir, second_logdir]

    ex.run()
93
94 View Code Duplication
def test_log_file_writer_as_context_manager_with_exception(ex, tf):
    """Check that capturing stops after an exception leaves ``LogFileWriter``.

    A ``ValueError`` is raised inside the ``with LogFileWriter(ex)`` block;
    a FileWriter created afterwards must not be added to
    ``info["tensorflow"]["logdirs"]``.
    """
    captured_logdir = "/tmp/sacred_test"

    @ex.main
    def run_experiment(_run):
        assert _run.info.get("tensorflow") is None
        with tf.Session() as session:
            # Log directories are recorded only while the context manager is active
            try:
                with LogFileWriter(ex):
                    writer = tf.summary.FileWriter(logdir=captured_logdir, graph=session.graph)
                    assert writer is not None
                    assert _run.info["tensorflow"]["logdirs"] == [captured_logdir]
                    raise ValueError("I want to be raised!")
            except ValueError:
                # Expected: the exception only serves to leave the context manager.
                pass
            # Created after the context manager was left: must not be recorded
            tf.summary.FileWriter("/tmp/whatever", session.graph)
            assert _run.info["tensorflow"]["logdirs"] == [captured_logdir]

    ex.run()
116
117
# Tests whether logdir is stored into the info dictionary when creating a new FileWriter object,
118
# but this time on a method of a class
119
def test_log_summary_writer_class(ex, tf):
    """Check that ``LogFileWriter`` also works as a decorator on a method.

    Only the FileWriter created inside the decorated ``FooClass.hello`` is
    expected under ``info["tensorflow"]["logdirs"]``; FileWriter objects
    created directly in the main function must not be recorded.
    """
    unrecorded_logdir = "/dev/null"
    recorded_logdir = "/tmp/sacred_test"

    class FooClass:
        def __init__(self):
            pass

        @LogFileWriter(ex)
        def hello(self, argument):
            with tf.Session() as session:
                tf.summary.FileWriter(argument, session.graph)

    @ex.main
    def run_experiment(_run):
        assert _run.info.get("tensorflow") is None
        foo = FooClass()
        with tf.Session() as session:
            writer = tf.summary.FileWriter(unrecorded_logdir, session.graph)
            assert writer is not None
            # Nothing recorded: FileWriter was not called in an annotated function
            assert _run.info.get("tensorflow") is None
        foo.hello(recorded_logdir)
        # Recorded, because foo.hello is annotated
        assert _run.info["tensorflow"]["logdirs"] == [recorded_logdir]

        with tf.Session() as session:
            tf.summary.FileWriter(unrecorded_logdir, session.graph)
            # Still unchanged: FileWriter was again not called in an annotated function
            assert _run.info["tensorflow"]["logdirs"] == [recorded_logdir]

    ex.run()
151
152
if __name__ == "__main__":
    # Fixtures must not be called like plain functions: pytest 4 and later
    # fail with "Fixture ... called directly". Let pytest resolve the `ex`
    # and `tf` fixtures itself by running the same single test via its node id,
    # and propagate pytest's exit code to the caller.
    raise SystemExit(
        pytest.main(["{}::test_log_file_writer".format(__file__)])
    )