Total Complexity | 12 |
Total Lines | 60 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | import time |
||
2 | from redis import Redis |
||
3 | from flask import current_app |
||
4 | |||
# Module-wide Redis client shared by all RateLimit instances. It starts as
# None and is assigned lazily inside RateLimit.__init__ (a FakeRedis when the
# app is TESTING, otherwise a real Redis connection).
redis = None
||
6 | |||
7 | |||
class FakeRedis:
    """In-memory Redis mock used for testing.

    Only the commands RateLimit issues are implemented. ``pipeline()`` hands
    back the object itself, so "queued" commands are applied immediately and
    ``execute()`` simply reports the counter of the most recently incremented
    key.
    """

    def __init__(self):
        self.v = {}  # key -> integer counter value
        self.last_key = None  # key touched by the latest incr() call

    def pipeline(self):
        """Return ``self``; commands are applied eagerly rather than batched."""
        return self

    def incr(self, key):
        """Add one to the counter at *key*, treating a missing key as 0."""
        previous = self.v.get(key)
        self.v[key] = (0 if previous is None else previous) + 1
        self.last_key = key

    def expireat(self, key, exp_time):
        """Accept and ignore an expiry request; keys never expire in the mock."""

    def execute(self):
        """Return a one-element list holding the last incremented counter.

        Raises KeyError when no incr() has happened yet.
        """
        latest = self.last_key
        return [self.v[latest]]
||
28 | |||
29 | |||
class RateLimit:
    """Fixed-window request counter backed by Redis.

    Constructing an instance counts one request. The counter key is
    ``key_prefix`` followed by the Unix timestamp at which the current window
    ends. It is incremented and given an expiry slightly after the window
    closes.

    Args:
        key_prefix: String prefix identifying what is being limited.
        limit: Maximum number of requests allowed per window.
        per: Window length in seconds.

    Attributes:
        reset: Unix timestamp (seconds) at which the current window ends.
        key: Redis key holding the counter for the current window.
        count: Raw number of requests seen in this window, including this one
            (0 when rate limiting is disabled and no client exists).
        current: ``count`` capped at ``limit``.
    """

    # Grace period (seconds) after the window ends before Redis expires the
    # window's counter key.
    expiration_window = 10

    def __init__(self, key_prefix, limit, per):
        client = self._get_client()

        # End of the current fixed window, aligned to multiples of `per`.
        self.reset = (int(time.time()) // per) * per + per
        self.key = key_prefix + str(self.reset)
        self.limit = limit
        self.per = per

        if client is None:
            # Rate limiting is disabled and no client was ever created: treat
            # the request as uncounted instead of failing on None.pipeline().
            self.count = 0
        else:
            p = client.pipeline()
            p.incr(self.key)
            p.expireat(self.key, self.reset + self.expiration_window)
            self.count = p.execute()[0]
        # Keep the raw count so that exactly `limit` requests are allowed;
        # `current` is clamped so `remaining` never goes negative.
        self.current = min(self.count, limit)

    @staticmethod
    def _get_client():
        """Return the module-wide Redis client, creating it on first use.

        A client is only created when ``USE_RATE_LIMITS`` is enabled. In
        ``TESTING`` mode an in-memory FakeRedis is used. Returns None when no
        client exists and rate limits are disabled.
        """
        global redis
        if redis is None and current_app.config['USE_RATE_LIMITS']:
            if current_app.config['TESTING']:
                redis = FakeRedis()
            else:  # pragma: no cover
                redis = Redis(
                    host=current_app.config.get("REDIS_HOST", "localhost"),
                    port=current_app.config.get("REDIS_PORT", 6379),
                    db=current_app.config.get("REDIS_DB", 0),
                )
        return redis

    @property
    def remaining(self):
        """Requests still allowed in this window (0 once the limit is used)."""
        return self.limit - self.current

    @property
    def over_limit(self):
        """True once more than ``limit`` requests were made in this window."""
        return self.count > self.limit
59 | return self.current >= self.limit |
||
60 |