Completed
Push — master ( 5644b6...2196d0 )
by Daniel
01:17
created

DatabaseClass.get()   A

Complexity

Conditions 3

Size

Total Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 3
c 1
b 0
f 0
dl 0
loc 5
rs 9.4285
1
import logging
2
3
from sqlalchemy import create_engine
4
from sqlalchemy.orm import scoped_session, sessionmaker
5
from sqlalchemy.ext.declarative import declarative_base
6
7
from groundwork.patterns import GwBasePattern
8
9
10
class GwSqlPattern(GwBasePattern):
    """
    Pattern which gives a plugin access to sql databases.

    If the application has no ``databases`` attribute yet, an application-wide
    :class:`SqlDatabasesApplication` is attached to ``app.databases``.
    Every plugin using this pattern additionally gets its own :class:`SqlDatabasesPlugin`
    as ``self.databases``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        app = self.app
        # The application-wide registry is shared by all plugins, so it is only created once.
        if not hasattr(app, "databases"):
            app.databases = SqlDatabasesApplication(app)

        #: Instance of :class:`~.SqlDatabasesPlugin`.
        #: Provides functions to register and manage sql database interfaces
        self.databases = SqlDatabasesPlugin(self)
22
23
24
class SqlDatabasesPlugin:
    """
    Plugin-level facade for sql databases.

    All calls are forwarded to the application-wide registry (``plugin.app.databases``),
    with this object's plugin passed along as the owner.
    """

    def __init__(self, plugin):
        """
        :param plugin: The owning plugin. Its ``app``, ``log``, ``signals`` and ``name``
                       attributes are used.
        """
        self.plugin = plugin
        self.app = plugin.app
        self.log = plugin.log

        # Let's register a receiver, which cares about the deactivation process of sql databases for this plugin.
        # We do it after the original plugin deactivation, so we can be sure that the registered function is the last
        # one which cares about sql databases for this plugin.
        self.plugin.signals.connect(receiver="%s_sql_deactivation" % self.plugin.name,
                                    signal="plugin_deactivate_post",
                                    function=self.__deactivate_sql_databases,
                                    description="Deactivates sql databases for %s" % self.plugin.name,
                                    sender=self.plugin)
        self.log.debug("Pattern sql databases initialised")

    def __deactivate_sql_databases(self, plugin, *args, **kwargs):
        """
        Signal receiver: unregisters every sql database registered by this plugin.

        ``plugin`` and any further arguments are supplied by the signal and are not used.
        """
        # Snapshot the names before unregistering: unregister() deletes registry entries, and
        # deleting from a dict while iterating over it raises RuntimeError if get() returns a live dict.
        for name in list(self.get()):
            self.unregister(name)

    def register(self, database, database_url, description):
        """
        Registers a new sql database for this plugin.

        :param database: Name of the database.
        :param database_url: Url of the database.
        :param description: Human readable description.
        :return: Whatever the application registry's ``register`` returns (the new database object).
        """
        return self.app.databases.register(database, database_url, description, self.plugin)

    def unregister(self, database):
        """
        Unregisters an existing database, so that it is no longer available.
        This function is mainly used during plugin deactivation.

        :param database: Name of the database.
        """
        return self.app.databases.unregister(database)

    def get(self, name=None):
        """
        Returns databases of this plugin, which can be filtered by name.

        :param name: name of the database
        :type name: str
        :return: None, single database or dict of databases
        """
        return self.app.databases.get(name, self.plugin)
67
68
69
class SqlDatabasesApplication:
    """
    Application-wide registry of sql databases, keyed by database name.

    Each entry is a :class:`Database`, which remembers the plugin that registered it (or ``None``).
    """

    def __init__(self, app):
        self.app = app
        self.log = logging.getLogger(__name__)
        #: Maps database name -> Database instance.
        self._databases = {}
        self.log.info("Application sql databases initialised")

    def register(self, database, database_url, description, plugin=None):
        """
        Registers a new sql database for a plugin.

        :param database: Unique name of the database.
        :param database_url: Url of the database, passed on to :class:`Database`.
        :param description: Human readable description.
        :param plugin: Plugin which owns the database, or None.
        :return: The newly created :class:`Database` instance.
        :raises DatabaseExistException: If a database with this name is already registered.
        """
        if database in self._databases:
            owner = self._databases[database].plugin
            # Databases may be registered without a plugin (plugin=None). Don't let the missing
            # owner turn the intended DatabaseExistException into an AttributeError.
            owner_name = owner.name if owner is not None else None
            raise DatabaseExistException("Database %s already registered by %s" % (database, owner_name))

        new_database = Database(database, database_url, description, plugin)
        self._databases[database] = new_database
        self.log.debug("Database registered: %s", database)
        return new_database

    def unregister(self, database):
        """
        Unregisters an existing database, so that it is no longer available via :meth:`get`.
        Unknown names are logged as a warning and otherwise ignored.
        This function is mainly used during plugin deactivation.

        :param database: Name of the database.
        """
        if database not in self._databases:
            self.log.warning("Can not unregister database %s. Reason: Database does not exist.", database)
            return
        del self._databases[database]
        self.log.debug("Database %s got unregistered", database)

    def get(self, name=None, plugin=None):
        """
        Returns databases, which can be filtered by name and/or owning plugin.

        :param name: name of the database
        :type name: str
        :param plugin: if given, only databases registered by this plugin are considered
        :return: None, single database or dict of databases.
                 With neither filter set, the internal registry dict itself (not a copy) is returned.
        """
        if name is None:
            if plugin is None:
                return self._databases
            return {key: db for key, db in self._databases.items() if db.plugin == plugin}

        if name not in self._databases:
            return None
        database = self._databases[name]
        if plugin is not None and database.plugin != plugin:
            return None
        return database
131
132
133
class Database:
    """
    A single sql database: SQLAlchemy engine, scoped session, declarative base and class registry.

    The helper methods ``commit``, ``query``, ``add``, ``delete``, ``rollback`` and ``close``
    forward all arguments unchanged to ``self.session`` and return its result.
    """

    def __init__(self, name, url, description, plugin):
        """
        :param name: Name of the database.
        :param url: Database url, passed to ``create_engine``.
        :param description: Human readable description.
        :param plugin: Plugin which registered the database, or None.
        """
        self.name, self.database_url = name, url
        self.description, self.plugin = description, plugin

        self.engine = create_engine(url)

        session_factory = sessionmaker(autocommit=False, autoflush=False, bind=self.engine)
        self.session = scoped_session(session_factory)
        self.Base = declarative_base()

        # Gives model classes a ``Model.query`` attribute (e.g. ``User.query``), as libraries like
        # flask-sqlalchemy do; plain sqlalchemy models don't have it.
        # Fore more visit: http://stackoverflow.com/a/28025843
        self.Base.query = self.session.query_property()

        self.classes = DatabaseClass(self.Base)

    def create_all(self):
        """Calls ``create_all`` of ``self.Base.metadata`` with this database's engine."""
        metadata = self.Base.metadata
        return metadata.create_all(self.engine)

    def commit(self, *a, **kw):
        """Forwards to ``self.session.commit``."""
        return self.session.commit(*a, **kw)

    def query(self, *a, **kw):
        """Forwards to ``self.session.query``."""
        return self.session.query(*a, **kw)

    def add(self, *a, **kw):
        """Forwards to ``self.session.add``."""
        return self.session.add(*a, **kw)

    def delete(self, *a, **kw):
        """Forwards to ``self.session.delete``."""
        return self.session.delete(*a, **kw)

    def rollback(self, *a, **kw):
        """Forwards to ``self.session.rollback``."""
        return self.session.rollback(*a, **kw)

    def close(self, *a, **kw):
        """Forwards to ``self.session.close``."""
        return self.session.close(*a, **kw)
175
176
177
class DatabaseClass:
    """
    Registry of classes belonging to one database, keyed by name.
    """

    def __init__(self, Base):
        """
        :param Base: Declarative base of the owning database. It is stored, but :meth:`register`
                     currently does not use it.
        """
        self._Base = Base
        self._classes = {}

    def register(self, clazz, name=None):
        """
        Registers a class.

        :param clazz: The class to register.
        :param name: Registry key; defaults to ``clazz.__name__``.
        :return: The registered class (the very object that was passed in).
        :raises DatabaseClassExistException: If ``name`` is already registered.
        """
        if name is None:
            name = clazz.__name__

        if name in self._classes:
            raise DatabaseClassExistException("Database class %s already registered" % name)

        # NOTE: Combining the user class with self._Base (by creating a new class that inherits from
        # both, since changing __bases__ dynamically does not work well) was considered but is disabled.
        # The class is stored unchanged; presumably callers pass classes that already derive from the
        # database's Base -- TODO confirm.
        self._classes[name] = clazz
        return clazz

    def unregister(self, name):
        """Removes and returns the class registered under ``name``, or None if there is none."""
        return self._classes.pop(name, None)

    def get(self, clazz_name=None):
        """Returns the class registered under ``clazz_name``, or None if it is unknown or not given."""
        if clazz_name is None:
            return None
        return self._classes.get(clazz_name)
208
209
210
class DatabaseExistException(Exception):
    """Raised when a sql database is registered under a name that is already in use."""
212
213
214
class DatabaseClassExistException(Exception):
    """Raised when a database class is registered under a name that is already in use."""
216