Passed
Push — master ( c3d045...3525ae )
by Konstantinos
01:55 queued 43s
created

TensorflowSessionRunner.run()   A

Complexity

Conditions 2

Size

Total Lines 21
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 10
nop 3
dl 0
loc 21
rs 9.9
c 0
b 0
f 0
1
2
from typing import List
3
import tensorflow as tf
4
5
from software_patterns import Proxy
6
7
8
class TensorflowSessionRunnerSubject:
    """Real subject for the runner proxy: forwards ``run`` to a held session.

    The session object passed in is stored as-is under ``interactive_session``;
    this class adds no behavior beyond delegating ``run`` calls to it.
    """

    def __init__(self, interactive_session) -> None:
        # Keep this exact attribute name: TensorflowSessionRunner.session reads it.
        self.interactive_session = interactive_session

    def run(self, *args, **kwargs):
        """Call ``run`` on the held session with the given arguments.

        Returns:
            Whatever the held session's ``run`` returns, unchanged.
        """
        held_session = self.interactive_session
        return held_session.run(*args, **kwargs)
14
15
16
class TensorflowSessionRunner(Proxy):
    """Proxy around a session-runner subject that records and wraps ``run`` calls.

    Each call to :meth:`run` appends a textual description of its positional
    and keyword arguments to :attr:`args_history`. The entry is added before
    the call executes, so failed calls are recorded too. The call is then
    delegated to the subject's ``run``. Any exception raised by the delegate is
    re-raised as :class:`TensorflowSessionRunError`, chained to the original.

    The runner never closes the underlying session. Callers own its lifecycle,
    e.g. call ``runner.session.close()`` when finished.
    """

    def __init__(self, real_subject) -> None:
        super().__init__(real_subject)
        # self._proxy_subject (presumably set by Proxy.__init__) refers to a
        # TensorflowSessionRunnerSubject instance, or anything exposing
        # ``run`` and ``interactive_session``.
        # One entry per run() call, formatted "ARGS: [...], KWARGS: [...]".
        # NOTE(review): this list grows without bound for the runner's lifetime.
        self.args_history: List[str] = []

    def run(self, *args, **kwargs):
        """Run the wrapped subject, recording the call arguments first.

        Args:
            *args: Positional arguments forwarded unchanged to the subject's
                ``run`` (e.g. fetches, when the subject wraps a TF session).
            **kwargs: Keyword arguments forwarded unchanged (e.g. ``feed_dict``).

        Returns:
            Whatever the subject's ``run`` returns.

        Raises:
            TensorflowSessionRunError: if the subject's ``run`` raises any
                ``Exception``; the original error is attached as ``__cause__``.
        """
        session_run_callable = self._proxy_subject.run
        args_str = f"[{', '.join(str(arg) for arg in args)}]"
        kwargs_str = f"[{', '.join(f'{key}={value}' for key, value in kwargs.items())}]"
        self.args_history.append(f"ARGS: {args_str}, KWARGS: {kwargs_str}")
        try:
            return session_run_callable(*args, **kwargs)
        except Exception as tensorflow_error:
            # Broad on purpose: callers see one exception type whatever the
            # session raised, and the original error stays chained.
            raise TensorflowSessionRunError(
                'Tensorflow error occurred when running session with input '
                f'args {args_str} and kwargs {kwargs_str}'
            ) from tensorflow_error

    @property
    def session(self):
        """The session object held by the wrapped subject (its ``interactive_session``)."""
        return self._proxy_subject.interactive_session

    @classmethod
    def with_default_graph_reset(cls):
        """Build a runner around a fresh ``tf.compat.v1.InteractiveSession``.

        Side effects on process-global TensorFlow state:
        - resets the default graph;
        - disables eager execution, before the session is created.

        Returns:
            A new instance of ``cls`` (so subclasses get their own type)
            wrapping a TensorflowSessionRunnerSubject.
        """
        tf.compat.v1.reset_default_graph()
        # Turn eager off so ops built afterwards go into the default graph
        # that the session runs, rather than executing immediately (TF2).
        tf.compat.v1.disable_eager_execution()
        return cls(TensorflowSessionRunnerSubject(tf.compat.v1.InteractiveSession()))
55
56
57
class TensorflowSessionRunError(Exception):
    """Raised by TensorflowSessionRunner.run when the wrapped session's ``run`` fails; the original error is chained as ``__cause__``."""
58