Passed
Push — master ( ee1e78...515b92 )
by Konstantinos
01:14
created

artificial_artwork.tf_session_runner   A

Complexity

Total Complexity 8

Size/Duplication

Total Lines 42
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 30
dl 0
loc 42
rs 10
c 0
b 0
f 0
wmc 8

7 Methods

Rating   Name   Duplication   Size   Complexity  
A TensorflowSessionRunner.__init__() 0 3 1
A TensorflowSessionRunner.with_default_graph_reset() 0 6 1
A TensorflowSessionRunnerSubject.request() 0 2 1
A TensorflowSessionRunnerSubject.__init__() 0 2 1
A TensorflowSessionRunner.run() 0 2 1
A TensorflowSessionRunner.request() 0 7 2
A TensorflowSessionRunner.session() 0 3 1
1
2
from typing import List
3
import tensorflow as tf
4
5
from .utils.proxy import RealSubject, Proxy
6
7
8
class TensorflowSessionRunnerSubject(RealSubject):
    """Real subject that hands every request to a session's ``run`` method."""

    def __init__(self, interactive_session) -> None:
        # Session object whose ``run`` method services every request.
        self.interactive_session = interactive_session

    def request(self, *args, **kwargs):
        """Forward all arguments to ``interactive_session.run`` and return its result."""
        run_on_session = self.interactive_session.run
        return run_on_session(*args, **kwargs)
14
15
16
class TensorflowSessionRunner(Proxy):
    """Proxy around a TensorFlow session subject that records every request.

    Each call to ``request`` (or its alias ``run``) appends a human-readable
    description of the positional and keyword arguments to ``args_history``
    before delegating to the wrapped real subject.
    """

    def __init__(self, real_subject) -> None:
        super().__init__(real_subject)
        # One "ARGS: [...], KWARGS: [...]" entry per request, in call order.
        self.args_history: List[str] = []

    def request(self, *args, **kwargs):
        """Record the call arguments, then delegate to the real subject.

        Returns:
            Whatever ``Proxy.request`` returns for these arguments.

        Raises:
            Any exception raised by the delegated call propagates unchanged;
            the history entry has already been appended at that point.
        """
        positional = ', '.join(str(arg) for arg in args)
        keyword = ', '.join(f'{key}={value}' for key, value in kwargs.items())
        self.args_history.append(f"ARGS: [{positional}], KWARGS: [{keyword}]")
        # Proxy executes the request by calling request() on the real subject.
        return super().request(*args, **kwargs)

    @property
    def session(self):
        """The interactive session held by the wrapped real subject."""
        # NOTE(review): relies on Proxy storing the subject as ``_real_subject``.
        return self._real_subject.interactive_session

    def run(self, *args, **kwargs):
        """Alias for ``request``, so the proxy can be called like ``session.run``."""
        return self.request(*args, **kwargs)

    @classmethod
    def with_default_graph_reset(cls):
        """Reset the default graph, disable eager execution and build a runner.

        Returns:
            An instance of ``cls`` (so subclasses get their own type) wrapping a
            ``TensorflowSessionRunnerSubject`` around a new
            ``tf.compat.v1.InteractiveSession``.
        """
        tf.compat.v1.reset_default_graph()
        # Session-based execution is not compatible with eager mode, so graph
        # mode is enabled before the session is created.
        tf.compat.v1.disable_eager_execution()
        session = tf.compat.v1.InteractiveSession()
        return cls(TensorflowSessionRunnerSubject(session))
42