Completed
Push — master ( e6dbce...addc19 )
by Roy
02:40
created

BaseDB   B

Complexity

Total Complexity 37

Size/Duplication

Total Lines 114
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
wmc 37
c 0
b 0
f 0
dl 0
loc 114
rs 8.6

9 Methods

Rating   Name   Duplication   Size   Complexity  
A dbcur() 0 3 1
A _update() 0 9 2
A escape() 0 3 1
A _execute() 0 4 1
A _delete() 0 8 2
C _select() 0 16 10
F _select2dic() 0 22 12
A _insert() 0 15 4
A _replace() 0 15 4
1
#!/usr/bin/env python
2
# -*- encoding: utf-8 -*-
3
# vim: set et sw=4 ts=4 sts=4 ff=unix fenc=utf8:
4
# Author: Binux<[email protected]>
5
#         http://binux.me
6
# Created on 2012-08-30 17:43:49
7
8
from __future__ import unicode_literals, division, absolute_import
9
10
import logging
11
# Module-level logger; the BaseDB methods log each generated SQL statement at DEBUG level.
logger = logging.getLogger('database.basedb')
12
13
from six import itervalues
14
15
16
class BaseDB(object):

    '''
    BaseDB

    Helper base class that builds simple single-table SQL statements and runs
    them through a cursor.

    ``dbcur`` must be overridden by subclasses to return a cursor object; the
    methods below call its ``execute`` method, iterate it, and read its
    ``description`` and ``lastrowid`` attributes.

    Class attributes a subclass may override:

    * ``__tablename__`` -- table used when no ``tablename`` argument is given.
    * ``placeholder``   -- the driver's parameter marker (``'%s'`` by default,
      ``'?'`` for sqlite3).
    * ``maxlimit``      -- row count emitted as ``LIMIT offset, maxlimit`` when
      an offset is requested without a limit.
    '''
    __tablename__ = None
    placeholder = '%s'
    maxlimit = -1

    @staticmethod
    def escape(string):
        '''
        Quote ``string`` as a backtick-delimited SQL identifier.

        Embedded backticks are doubled (the MySQL/SQLite escape for quoted
        identifiers) so a table or column name cannot terminate the quoting
        early and inject SQL.
        '''
        name = '%s' % (string,)
        return '`%s`' % name.replace('`', '``')

    @property
    def dbcur(self):
        '''Return a cursor to run statements on; subclasses must override.'''
        raise NotImplementedError

    def _execute(self, sql_query, values=None):
        '''
        Execute ``sql_query`` with the bound parameters ``values``.

        ``values`` defaults to a fresh empty list per call (no shared mutable
        default). Returns the cursor the statement was executed on.
        '''
        dbcur = self.dbcur
        dbcur.execute(sql_query, [] if values is None else values)
        return dbcur

    def _build_select_sql(self, tablename, what, where, order, offset, limit):
        '''
        Build the SELECT statement shared by ``_select`` and ``_select2dic``.

        ``what`` may be a raw column-expression string, or a list/tuple of
        column names (each escaped); ``None`` or an empty sequence means ``*``.
        ``where`` and ``order`` are raw SQL fragments. A falsy ``limit`` means
        "no limit"; an offset without a limit uses ``self.maxlimit``.
        '''
        tablename = self.escape(tablename or self.__tablename__)
        if isinstance(what, (list, tuple)) or what is None:
            what = ','.join(self.escape(f) for f in what) if what else '*'

        sql_query = "SELECT %s FROM %s" % (what, tablename)
        if where:
            sql_query += " WHERE %s" % where
        if order:
            sql_query += ' ORDER BY %s' % order
        if limit:
            sql_query += " LIMIT %d, %d" % (offset, limit)
        elif offset:
            sql_query += " LIMIT %d, %d" % (offset, self.maxlimit)
        return sql_query

    def _select(self, tablename=None, what="*", where="", where_values=None,
                offset=0, limit=None, order=None):
        '''
        Yield the raw rows of a SELECT on ``tablename``.

        ``where_values`` are bound to the placeholders in ``where``. ``order``
        is an optional raw ORDER BY expression; it is the last parameter so
        existing positional callers are unaffected. This is a generator, so
        the query only runs once iteration starts.
        '''
        sql_query = self._build_select_sql(tablename, what, where, order, offset, limit)
        logger.debug("<sql: %s>", sql_query)

        for row in self._execute(sql_query, where_values):
            yield row

    def _select2dic(self, tablename=None, what="*", where="", where_values=None,
                    order=None, offset=0, limit=None):
        '''
        Like ``_select`` but yield each row as a dict keyed by column name.

        Column names are taken from the first item of each entry in the
        cursor's ``description``.
        '''
        sql_query = self._build_select_sql(tablename, what, where, order, offset, limit)
        logger.debug("<sql: %s>", sql_query)

        dbcur = self._execute(sql_query, where_values)
        fields = [f[0] for f in dbcur.description]

        for row in dbcur:
            yield dict(zip(fields, row))

    def _build_insert_sql(self, verb, tablename, values):
        '''
        Build ``<verb> INTO table (cols) VALUES (placeholders)`` for ``values``.

        With no values, fall back to ``<verb> INTO table DEFAULT VALUES``.
        '''
        tablename = self.escape(tablename or self.__tablename__)
        if values:
            _keys = ", ".join(self.escape(k) for k in values)
            _values = ", ".join([self.placeholder] * len(values))
            return "%s INTO %s (%s) VALUES (%s)" % (verb, tablename, _keys, _values)
        return "%s INTO %s DEFAULT VALUES" % (verb, tablename)

    def _write_row(self, verb, tablename, values):
        '''Run an INSERT-style statement for one row and return ``lastrowid``.'''
        sql_query = self._build_insert_sql(verb, tablename, values)
        logger.debug("<sql: %s>", sql_query)

        if values:
            # Same dict iterated for column names above, so order matches.
            dbcur = self._execute(sql_query, list(itervalues(values)))
        else:
            dbcur = self._execute(sql_query)
        return dbcur.lastrowid

    def _replace(self, tablename=None, **values):
        '''REPLACE one row built from ``values``; return the cursor's ``lastrowid``.'''
        return self._write_row("REPLACE", tablename, values)

    def _insert(self, tablename=None, **values):
        '''INSERT one row built from ``values``; return the cursor's ``lastrowid``.'''
        return self._write_row("INSERT", tablename, values)

    def _update(self, tablename=None, where="1=0", where_values=None, **values):
        '''
        UPDATE the columns in ``values`` on rows matching ``where``.

        The default ``where`` of ``1=0`` matches nothing, so an accidental call
        without a condition updates no rows. Column values are bound first,
        then ``where_values``. Returns the cursor. Note: an empty ``values``
        yields an invalid ``SET`` clause and the driver will reject it.
        '''
        tablename = self.escape(tablename or self.__tablename__)
        _key_values = ", ".join(
            "%s = %s" % (self.escape(k), self.placeholder) for k in values
        )
        sql_query = "UPDATE %s SET %s WHERE %s" % (tablename, _key_values, where)
        logger.debug("<sql: %s>", sql_query)

        params = list(itervalues(values))
        if where_values is not None:
            params.extend(where_values)
        return self._execute(sql_query, params)

    def _delete(self, tablename=None, where="1=0", where_values=None):
        '''
        DELETE rows matching ``where`` and return the cursor.

        The default ``1=0`` deletes nothing; an empty ``where`` omits the
        WHERE clause and therefore deletes every row.
        '''
        tablename = self.escape(tablename or self.__tablename__)
        sql_query = "DELETE FROM %s" % tablename
        if where:
            sql_query += " WHERE %s" % where
        logger.debug("<sql: %s>", sql_query)

        return self._execute(sql_query, where_values)
131
if __name__ == "__main__":
    # Smoke test of BaseDB against an in-memory sqlite3 database.
    import sqlite3

    class DB(BaseDB):
        '''Minimal sqlite3-backed BaseDB used only by this self-test.'''
        __tablename__ = "test"
        placeholder = "?"

        def __init__(self):
            self.conn = sqlite3.connect(":memory:")
            cursor = self.conn.cursor()
            cursor.execute(
                '''CREATE TABLE `%s` (id INTEGER PRIMARY KEY AUTOINCREMENT, name, age)'''
                % self.__tablename__
            )

        @property
        def dbcur(self):
            return self.conn.cursor()

    db = DB()
    # Use the next() builtin: generator.next() only exists on Python 2.
    assert db._insert(db.__tablename__, name="binux", age=23) == 1
    assert next(db._select(db.__tablename__, "name, age")) == ("binux", 23)
    assert next(db._select2dic(db.__tablename__, "name, age"))["name"] == "binux"
    assert next(db._select2dic(db.__tablename__, "name, age"))["age"] == 23
    # REPLACE rewrites the whole row, so the omitted `name` becomes NULL.
    db._replace(db.__tablename__, id=1, age=24)
    assert next(db._select(db.__tablename__, "name, age")) == (None, 24)
    db._update(db.__tablename__, "id = 1", age=16)
    assert next(db._select(db.__tablename__, "name, age")) == (None, 16)
    db._delete(db.__tablename__, "id = 1")
    assert list(db._select(db.__tablename__)) == []
    db.conn.close()