Session.__init__()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 2
rs 10
1
# -*- coding: utf8 -*-
2
import pytest
3
4
from sacred import Experiment
5
from sacred.stflow import LogFileWriter
6
7
8
@pytest.fixture
def ex():
    """Build a fresh sacred Experiment for each test.

    Returns:
        Experiment: a new experiment named ``'tensorflow_tests'``.
    """
    experiment = Experiment('tensorflow_tests')
    return experiment
11
12
13
# Provides the real tensorflow module when it is installed; otherwise builds a
# simplified stand-in so tensorflow is not required to run these tests.
@pytest.fixture()
def tf():
    """Return tensorflow, or a minimal mock of the parts these tests use.

    The mock exposes ``summary.FileWriter(logdir, graph)``, which stores both
    arguments and prints them. It also exposes ``Session``, a context manager
    whose ``graph`` attribute is ``None``. When the mock is used, it is also
    assigned to ``sacred.stflow.method_interception.tensorflow``.
    """
    from sacred.optional import has_tensorflow

    if has_tensorflow:
        import tensorflow
        return tensorflow

    # tensorflow is unavailable: assemble a stand-in with only the names used
    # by the tests (tf.summary.FileWriter and tf.Session).
    class MockTensorflow:
        class summary:
            class FileWriter:
                def __init__(self, logdir, graph):
                    self.logdir = logdir
                    self.graph = graph
                    message = "Mocked FileWriter got logdir=%s, graph=%s" % (logdir, graph)
                    print(message)

        class Session:
            def __init__(self):
                self.graph = None

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc_val, exc_tb):
                return None

    # Make sacred.stflow use the mock. This presumes stflow looks tensorflow
    # up through this module attribute.
    import sacred.stflow.method_interception as interception
    interception.tensorflow = MockTensorflow
    return MockTensorflow
45
46
47
def test_log_file_writer(ex, tf):
    """Check that creating a FileWriter stores its logdir in the run's info dict.

    ``run_experiment`` is decorated with ``LogFileWriter(ex)`` and also opens a
    nested ``LogFileWriter(ex)`` context. Every FileWriter created while either
    one is active is expected to be appended to
    ``_run.info["tensorflow"]["logdirs"]``.
    """
    first_logdir = "/dev/null"
    second_logdir = "/tmp/sacred_test"

    @ex.main
    @LogFileWriter(ex)
    def run_experiment(_run):
        assert _run.info.get("tensorflow", None) is None
        with tf.Session() as session:
            # A nested capture scope inside the already-decorated function.
            with LogFileWriter(ex):
                writer = tf.summary.FileWriter(logdir=first_logdir, graph=session.graph)
            assert writer is not None
            assert _run.info["tensorflow"]["logdirs"] == [first_logdir]
            # Outside the nested ``with`` but still inside the decorated
            # function, so this logdir is expected to be captured too.
            tf.summary.FileWriter(second_logdir, session.graph)
            expected_logdirs = [first_logdir, second_logdir]
            assert _run.info["tensorflow"]["logdirs"] == expected_logdirs

    ex.run()
65
66
67
def test_log_summary_writer_as_context_manager(ex, tf):
    """Check that the LogFileWriter context manager captures TensorFlow log directories.

    Only FileWriters created inside the ``with LogFileWriter(ex)`` block are
    recorded. Writers created before the block or after it has exited are not.
    """
    logdir_a = "/dev/null"
    logdir_b = "/tmp/sacred_test"

    @ex.main
    def run_experiment(_run):
        assert _run.info.get("tensorflow", None) is None
        with tf.Session() as session:
            # No LogFileWriter is active yet, so nothing may be recorded.
            writer = tf.summary.FileWriter(logdir=logdir_a, graph=session.graph)
            assert writer is not None
            assert _run.info.get("tensorflow", None) is None

            # Inside the context manager, each new FileWriter is captured.
            with LogFileWriter(ex):
                writer = tf.summary.FileWriter(logdir=logdir_a, graph=session.graph)
                assert writer is not None
                assert _run.info["tensorflow"]["logdirs"] == [logdir_a]
                tf.summary.FileWriter(logdir_b, session.graph)
                assert _run.info["tensorflow"]["logdirs"] == [logdir_a, logdir_b]

            # Once the context manager has exited, capturing stops again.
            tf.summary.FileWriter("/tmp/whatever", session.graph)
            assert _run.info["tensorflow"]["logdirs"] == [logdir_a, logdir_b]

    ex.run()
94
def test_log_file_writer_as_context_manager_with_exception(ex, tf):
    """Check that LogFileWriter stops capturing after an exception leaves its scope.

    An exception raised inside ``with LogFileWriter(ex)`` must propagate out of
    the context manager. FileWriters created after the block must not be
    recorded.
    """
    TEST_LOG_DIR = "/tmp/sacred_test"

    @ex.main
    def run_experiment(_run):
        assert _run.info.get("tensorflow", None) is None
        with tf.Session() as s:
            # Capturing the log directory should be done only in scope of the
            # context manager. pytest.raises (rather than try/except/pass) also
            # fails the test if LogFileWriter swallows the exception.
            with pytest.raises(ValueError):
                with LogFileWriter(ex):
                    swr = tf.summary.FileWriter(logdir=TEST_LOG_DIR, graph=s.graph)
                    assert swr is not None
                    assert _run.info["tensorflow"]["logdirs"] == [TEST_LOG_DIR]
                    raise ValueError("I want to be raised!")
            # This should not be captured:
            tf.summary.FileWriter("/tmp/whatever", s.graph)
            assert _run.info["tensorflow"]["logdirs"] == [TEST_LOG_DIR]

    ex.run()
117
118
def test_log_summary_writer_class(ex, tf):
    """Check that LogFileWriter captures logdirs when it decorates a class method.

    Only FileWriters created while the decorated ``FooClass.hello`` runs are
    recorded. FileWriters created directly in the main function are not.
    """
    TEST_LOG_DIR = "/dev/null"
    TEST_LOG_DIR2 = "/tmp/sacred_test"

    class FooClass:
        @LogFileWriter(ex)
        def hello(self, argument):
            with tf.Session() as s:
                tf.summary.FileWriter(argument, s.graph)

    @ex.main
    def run_experiment(_run):
        assert _run.info.get("tensorflow", None) is None
        foo = FooClass()
        with tf.Session() as s:
            swr = tf.summary.FileWriter(TEST_LOG_DIR, s.graph)
            assert swr is not None
            # Nothing is recorded: FileWriter was not created in a decorated function.
            assert _run.info.get("tensorflow", None) is None
        foo.hello(TEST_LOG_DIR2)
        # foo.hello is decorated with LogFileWriter, so its logdir is recorded.
        assert _run.info["tensorflow"]["logdirs"] == [TEST_LOG_DIR2]

        with tf.Session() as s:
            tf.summary.FileWriter(TEST_LOG_DIR, s.graph)
            # Nothing new is recorded: FileWriter was again created outside a
            # decorated function.
            assert _run.info["tensorflow"]["logdirs"] == [TEST_LOG_DIR2]

    ex.run()
152
153
if __name__ == "__main__":
    # pytest (>= 4) refuses direct calls to fixture functions such as ex() and
    # tf(). Hand this module to pytest instead: it injects the fixtures and runs
    # every test here, not just test_log_file_writer. Exit with pytest's status.
    raise SystemExit(pytest.main([__file__]))
155