1
|
|
|
import importlib |
2
|
|
|
import os |
3
|
|
|
from collections import defaultdict |
4
|
|
|
|
5
|
|
|
import ZODB |
6
|
|
|
import ZODB.FileStorage |
7
|
|
|
import arrow |
8
|
|
|
import persistent |
9
|
|
|
import transaction |
10
|
|
|
import zodburi |
11
|
|
|
|
12
|
|
|
import config |
13
|
|
|
from log import logger |
14
|
|
|
|
15
|
|
|
|
16
|
|
|
class User(persistent.Persistent):
    """Persistent record holding the profile fields of one user."""

    def __init__(self, user_id):
        """Create a user identified by ``user_id``.

        The name fields are not set here; they only exist after
        :meth:`update` has been called.
        """
        self.id = user_id

    def update(self, first_name, last_name, username):
        """Store the user's names, replacing falsy values with ``''``.

        :param first_name: first name, or a falsy value if unknown.
        :param last_name: last name, or a falsy value if unknown.
        :param username: username, or a falsy value if unknown.
        """
        self.first_name = first_name or ''
        self.last_name = last_name or ''
        self.username = username or ''
        # The persistence flag is spelled ``_p_changed``. The previous
        # ``_p_changed__`` only created a stray attribute and never marked
        # the object as modified.
        self._p_changed = True
25
|
|
|
|
26
|
|
|
|
27
|
|
|
class Entry(object):
    """A single booked amount, tied to the message that created it."""

    def __init__(self, message_id, user_id, amount, date, reason=None):
        """Record one entry.

        :param message_id: identifier of the originating message.
        :param user_id: identifier of the user the amount belongs to.
        :param amount: booked amount.
        :param date: date of the entry, stored as given.
        :param reason: free-text reason; any falsy value becomes ``'stuff'``.
        """
        self.message_id = message_id
        self.user_id = user_id
        self.amount = amount
        self.date = date
        # Fall back to the generic placeholder when no reason was supplied.
        self.reason = reason or 'stuff'
34
|
|
|
|
35
|
|
|
|
36
|
|
|
class Tab(persistent.Persistent):
    """Persistent tab of a chat: its entries plus running totals.

    ``entries`` is kept ordered by descending ``message_id``; ``add`` and
    ``get_entries`` both rely on that order. ``users`` maps a user id to the
    sum of that user's amounts. ``entries`` and ``users`` are plain (non
    persistent) containers, so in-place mutations must be followed by
    ``self._p_changed = True`` to be written back.
    """

    def __init__(self, chat_id):
        """Create an empty tab for ``chat_id`` using the ``'UTC'`` timezone."""
        self.chat_id = chat_id
        self.grandtotal = 0
        self.entries = []
        self.tz = 'UTC'
        self.users = defaultdict(int)

    def clear(self):
        """Drop all entries and reset the grand total and per-user totals."""
        self.entries = []
        self.grandtotal = 0
        self.users = defaultdict(int)
        self._p_changed = True

    def set_timezone(self, tz):
        """Set the timezone applied to the dates of entries added later."""
        self.tz = tz

    def remove(self, message_id, user_id, date, amount, reason=None):
        """Book ``amount`` as a negative entry; see :meth:`add`."""
        return self.add(message_id, user_id, date, -1 * amount, reason)

    def register_user(self, user_id):
        """Make sure ``user_id`` has a per-user total, starting at 0."""
        if user_id not in self.users:
            self.users[user_id] = 0
            # ``users`` is a plain dict subclass, so the change has to be
            # flagged explicitly or it may never be persisted.
            self._p_changed = True

    def add(self, message_id, user_id, date, amount, reason=''):
        """Insert a new entry and update the totals.

        Nothing happens when an entry with the same ``message_id`` is found
        while scanning for the insert position.

        :param message_id: identifier of the originating message; also the
            sort key of ``entries``.
        :param user_id: user whose total is adjusted by ``amount``.
        :param date: value accepted by ``arrow.get``; it is converted to the
            tab's timezone before being stored.
        :param amount: amount to add (may be negative).
        :param reason: optional free-text reason.
        """
        # Default to appending: if no existing entry has a smaller
        # message_id, the new one belongs at the very end of the list.
        position = len(self.entries)
        for index, existing in enumerate(self.entries):
            if existing.message_id == message_id:
                # Already in list, ignore
                logger.debug('not adding {}, already in list'.format(amount))
                return
            if existing.message_id < message_id:
                position = index
                break

        date = arrow.get(date).to(self.tz)
        entry = Entry(message_id, user_id, amount, date, reason)
        logger.debug('adding {}'.format(amount))
        self.entries.insert(position, entry)

        self.grandtotal += amount
        self.users[user_id] += amount

        self._p_changed = True

    def get_entries(self, from_date=None, to_date=None):
        """Return the entries whose date lies in ``[from_date, to_date]``.

        Either bound may be omitted (falsy) to leave that side open. An
        inverted range (``from_date > to_date``) yields an empty list.
        """
        if from_date and to_date and from_date > to_date:
            return []

        entries = []
        for entry in self.entries:
            if to_date and to_date < entry.date:
                continue

            if from_date and entry.date < from_date:
                # Stop at the first entry before the lower bound; this
                # assumes dates do not increase further down the list.
                break

            entries.append(entry)

        return entries

    def get_total(self, from_date=None, to_date=None):
        """Return the sum of the amounts in the given date range.

        Without any bound the cached grand total is returned. An inverted
        range (``from_date > to_date``) returns ``-1``.
        """
        if not from_date and not to_date:
            return self.grandtotal

        if from_date and to_date and from_date > to_date:
            return -1

        return sum(entry.amount for entry in self.get_entries(from_date, to_date))
109
|
|
|
|
110
|
|
|
|
111
|
|
|
class DB(object):
    """Wrapper around the ZODB database that stores tabs and users."""

    def __init__(self):
        """Open the database named by ``config.database_url`` and migrate it."""
        # Setup DB
        resolved = zodburi.resolve_uri(config.database_url)
        storage_factory = resolved[0]
        self._db = ZODB.DB(storage_factory())
        self._connection = self._db.open()
        self.root = self._connection.root
        self.migrate()

    def migrate(self):
        """Run the applicable migrations from the ``migrations`` directory.

        Files are processed in sorted name order. Only names starting with
        ``migration_`` and not ending in ``.pyc`` are considered; the last
        three characters of the name are stripped to obtain the module name.
        """
        for file_name in sorted(os.listdir('migrations')):
            is_migration = (
                file_name.startswith('migration_')
                and not file_name.endswith('.pyc')
            )
            if not is_migration:
                continue

            module_name = 'migrations.{}'.format(file_name[:-3])
            migration_cls = importlib.import_module(module_name).Migration
            if migration_cls.is_applicable(self.root):
                migration = migration_cls(self.root)
                logger.info('Applying migration {}'.format(migration.DB_VERSION))
                migration.apply()
            else:
                logger.debug('Skipping migration {}'.format(migration_cls.DB_VERSION))

    def get_or_create_tab(self, tab_id):
        """Return ``(tab, created)`` for ``tab_id``, creating the tab if needed."""
        tabs = self.root.tabs
        if tab_id not in tabs:
            new_tab = Tab(tab_id)
            self.root.tabs[tab_id] = new_tab
            self.root.stats['number_of_tabs'] += 1
            logger.debug('Created tab {}'.format(tab_id))
            return new_tab, True

        return tabs[tab_id], False

    def get_or_create_user(self, user_id):
        """Return ``(user, created)`` for ``user_id``, creating the user if needed."""
        users = self.root.users
        if user_id not in users:
            new_user = User(user_id)
            self.root.users[user_id] = new_user
            self.root.stats['number_of_users'] += 1
            logger.debug('Created user {}'.format(user_id))
            return new_user, True

        return users[user_id], False

    def commit(self):
        """Commit the current transaction via the ``transaction`` package."""
        transaction.commit()

    def close(self):
        """Close the underlying ZODB database."""
        self._db.close()
158
|
|
|
|
Duplicated code is one of the most pungent code smells. If you need to duplicate the same code in three or more different places, we strongly encourage you to look into extracting the code into a single class or operation.
You can also find more detailed suggestions in the “Code” section of your repository.