BaseDB.escape()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
#!/usr/bin/env python
2
# -*- encoding: utf-8 -*-
3
# vim: set et sw=4 ts=4 sts=4 ff=unix fenc=utf8:
4
# Author: Binux<[email protected]>
5
#         http://binux.me
6
# Created on 2012-08-30 17:43:49
7
8
from __future__ import unicode_literals, division, absolute_import
9
10
import logging
11
logger = logging.getLogger('database.basedb')
12
13
from six import itervalues
14
from pyspider.libs import utils
15
16
17
class BaseDB:

    '''
    BaseDB

    Minimal helper base class that builds SQL statements and runs them
    through a DB-API style cursor.

    Subclasses must override ``dbcur`` to return a cursor. They may also
    override:

    * ``__tablename__``: the default table.
    * ``placeholder``: the driver's bound-parameter marker, e.g. '?' for
      sqlite3.
    * ``maxlimit``: the row count used in "LIMIT offset, n" when only an
      offset is requested.
    '''
    __tablename__ = None
    placeholder = '%s'
    maxlimit = -1

    @staticmethod
    def escape(string):
        """Quote ``string`` as a backtick-delimited SQL identifier.

        Embedded backticks are doubled (the MySQL/SQLite escape for quoted
        identifiers). This stops a name containing a backtick from
        terminating the quoting early and injecting SQL.
        """
        return '`%s`' % ('%s' % string).replace('`', '``')

    @property
    def dbcur(self):
        """Return a cursor for the underlying connection (subclass must override)."""
        raise NotImplementedError

    def _execute(self, sql_query, values=()):
        """Execute ``sql_query`` with bound ``values`` and return the cursor used."""
        dbcur = self.dbcur
        dbcur.execute(sql_query, values)
        return dbcur

    def _build_select(self, tablename, what, where, order, offset, limit):
        """Build the SELECT statement shared by ``_select`` and ``_select2dic``."""
        tablename = self.escape(tablename or self.__tablename__)
        # A list/tuple of column names is escaped and joined. An empty
        # sequence or None means all columns. A plain string is used
        # verbatim (e.g. "name, age").
        if isinstance(what, (list, tuple)) or what is None:
            what = ','.join(self.escape(f) for f in what) if what else '*'

        sql_query = "SELECT %s FROM %s" % (what, tablename)
        if where:
            sql_query += " WHERE %s" % where
        if order:
            sql_query += ' ORDER BY %s' % order
        # A falsy limit (None or 0) means "no limit". The "LIMIT offset, n"
        # form always needs a count, so an offset on its own is paired with
        # ``maxlimit``. The default of -1 presumably means "unbounded", as
        # SQLite treats a negative LIMIT; backends that reject it should
        # override maxlimit.
        if limit:
            sql_query += " LIMIT %d, %d" % (offset, limit)
        elif offset:
            sql_query += " LIMIT %d, %d" % (offset, self.maxlimit)
        return sql_query

    def _select(self, tablename=None, what="*", where="", where_values=(), offset=0, limit=None):
        """Yield the raw rows of a SELECT exactly as the cursor returns them.

        ``where`` is inserted verbatim; its placeholders are bound from
        ``where_values``.
        """
        sql_query = self._build_select(tablename, what, where, None, offset, limit)
        logger.debug("<sql: %s>", sql_query)

        for row in self._execute(sql_query, where_values):
            yield row

    def _select2dic(self, tablename=None, what="*", where="", where_values=(),
                    order=None, offset=0, limit=None):
        """Yield each row of a SELECT as a dict mapping column name to value."""
        sql_query = self._build_select(tablename, what, where, order, offset, limit)
        logger.debug("<sql: %s>", sql_query)

        dbcur = self._execute(sql_query, where_values)

        # f[0] may return bytes type
        # https://github.com/mysql/mysql-connector-python/pull/37
        fields = [utils.text(f[0]) for f in dbcur.description]

        for row in dbcur:
            yield dict(zip(fields, row))

    def _write_row(self, verb, tablename, values):
        """Run "<verb> INTO" for one row built from ``values``; return lastrowid."""
        tablename = self.escape(tablename or self.__tablename__)
        if values:
            # Keys and values are read from the same dict, so their
            # iteration orders match.
            _keys = ", ".join(self.escape(k) for k in values)
            _values = ", ".join([self.placeholder, ] * len(values))
            sql_query = "%s INTO %s (%s) VALUES (%s)" % (verb, tablename, _keys, _values)
        else:
            sql_query = "%s INTO %s DEFAULT VALUES" % (verb, tablename)
        logger.debug("<sql: %s>", sql_query)

        if values:
            dbcur = self._execute(sql_query, list(itervalues(values)))
        else:
            dbcur = self._execute(sql_query)
        return dbcur.lastrowid

    def _replace(self, tablename=None, **values):
        """REPLACE one row with the given column values; return the cursor's lastrowid."""
        return self._write_row("REPLACE", tablename, values)

    def _insert(self, tablename=None, **values):
        """INSERT one row with the given column values; return the cursor's lastrowid."""
        return self._write_row("INSERT", tablename, values)

    def _update(self, tablename=None, where="1=0", where_values=(), **values):
        """UPDATE rows matching ``where`` with the given column values.

        The default ``where`` of "1=0" matches nothing. An empty ``values``
        yields a SET clause with no assignments, which is not valid SQL.
        Returns the cursor used.
        """
        tablename = self.escape(tablename or self.__tablename__)
        _key_values = ", ".join([
            "%s = %s" % (self.escape(k), self.placeholder) for k in values
        ])
        sql_query = "UPDATE %s SET %s WHERE %s" % (tablename, _key_values, where)
        logger.debug("<sql: %s>", sql_query)

        # SET placeholders come first in the statement, then WHERE placeholders.
        return self._execute(sql_query, list(itervalues(values)) + list(where_values))

    def _delete(self, tablename=None, where="1=0", where_values=()):
        """DELETE rows matching ``where`` and return the cursor used.

        The default "1=0" deletes nothing. A falsy ``where`` omits the WHERE
        clause and therefore deletes every row.
        """
        tablename = self.escape(tablename or self.__tablename__)
        sql_query = "DELETE FROM %s" % tablename
        if where:
            sql_query += " WHERE %s" % where
        logger.debug("<sql: %s>", sql_query)

        return self._execute(sql_query, where_values)
134
135
if __name__ == "__main__":
    # Self-test of the BaseDB helpers against an in-memory SQLite database.
    import sqlite3

    class DB(BaseDB):
        __tablename__ = "test"
        placeholder = "?"

        def __init__(self):
            self.conn = sqlite3.connect(":memory:")
            cursor = self.conn.cursor()
            cursor.execute(
                '''CREATE TABLE `%s` (id INTEGER PRIMARY KEY AUTOINCREMENT, name, age)'''
                % self.__tablename__
            )

        @property
        def dbcur(self):
            return self.conn.cursor()

    db = DB()
    # Use the next() builtin: generator.next() exists only on Python 2.
    assert db._insert(db.__tablename__, name="binux", age=23) == 1
    assert next(db._select(db.__tablename__, "name, age")) == ("binux", 23)
    assert next(db._select2dic(db.__tablename__, "name, age"))["name"] == "binux"
    assert next(db._select2dic(db.__tablename__, "name, age"))["age"] == 23
    # REPLACE rewrites the whole row, so the omitted "name" becomes NULL.
    db._replace(db.__tablename__, id=1, age=24)
    assert next(db._select(db.__tablename__, "name, age")) == (None, 24)
    db._update(db.__tablename__, "id = 1", age=16)
    assert next(db._select(db.__tablename__, "name, age")) == (None, 16)
    db._delete(db.__tablename__, "id = 1")
    assert list(db._select(db.__tablename__)) == []
165