Completed
Push — master ( ee223a...7e3034 )
by Ali-Akber
26s
created

FlaskExtTests.test_custom_key_prefix()   C

Complexity

Conditions 7

Size

Total Lines 35

Duplication

Lines 0
Ratio 0 %

Importance

Changes 2
Bugs 0 Features 0
Metric Value
cc 7
c 2
b 0
f 0
dl 0
loc 35
rs 5.5
1
"""
2
3
"""
4
import json
5
import logging
6
import time
7
import unittest
8
9
import hiro
10
import mock
11
import redis
12
import datetime
13
from flask import Flask, Blueprint, request, current_app, make_response
14
from flask_restful import Resource
15
from flask.views import View, MethodView
16
from limits.errors import ConfigurationError
17
from limits.storage import MemcachedStorage
18
from limits.strategies import MovingWindowRateLimiter
19
20
from flask_limiter.extension import C, Limiter, HEADERS
21
from flask_limiter.util import get_remote_address, get_ipaddr
22
23
24
class FlaskExtTests(unittest.TestCase):
25
    def setUp(self):
        """Start every test from an empty Redis keyspace."""
        client = redis.Redis()
        client.flushall()
27
28
    def build_app(self, config=None, **limiter_args):
        """Create a Flask app wired to a fresh ``Limiter``.

        :param config: optional mapping of Flask config keys to values; each
            entry is applied with ``setdefault`` so keys already present win.
        :param limiter_args: keyword arguments forwarded to ``Limiter``;
            ``key_func`` defaults to ``get_remote_address``.
        :return: ``(app, limiter)`` tuple.
        """
        app = Flask(__name__)
        # ``None`` rather than ``{}`` as default: avoids a shared mutable default.
        for key, value in (config or {}).items():
            app.config.setdefault(key, value)
        limiter_args.setdefault('key_func', get_remote_address)
        limiter = Limiter(app, **limiter_args)
        # Capture log output on a mock handler instead of emitting it.
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)
        # Detach the handler once the test finishes so handlers do not pile up
        # on the logger across tests (it may be shared between Limiters).
        self.addCleanup(limiter.logger.removeHandler, mock_handler)
        return app, limiter
38
39
    def test_invalid_strategy(self):
        """An unknown strategy name in the app config is rejected."""
        app = Flask(__name__)
        app.config.setdefault(C.STRATEGY, "fubar")
        with self.assertRaises(ConfigurationError):
            Limiter(app, key_func=get_remote_address)
43
44
    def test_invalid_storage_string(self):
        """A storage URL with an unknown scheme is rejected."""
        app = Flask(__name__)
        app.config.setdefault(C.STORAGE_URL, "fubar://localhost:1234")
        with self.assertRaises(ConfigurationError):
            Limiter(app, key_func=get_remote_address)
48
49
    def test_constructor_arguments_over_config(self):
        """Arguments passed to ``Limiter`` win over the app configuration."""
        app = Flask(__name__)
        app.config.setdefault(C.STRATEGY, "fixed-window-elastic-expiry")

        # Strategy: constructor asks for moving-window, config for elastic expiry.
        strategy_limiter = Limiter(strategy='moving-window', key_func=get_remote_address)
        strategy_limiter.init_app(app)
        app.config.setdefault(C.STORAGE_URL, "redis://localhost:6379")
        self.assertEqual(type(strategy_limiter._limiter), MovingWindowRateLimiter)

        # Storage: constructor asks for memcached, config now says redis.
        storage_limiter = Limiter(storage_uri='memcached://localhost:11211', key_func=get_remote_address)
        storage_limiter.init_app(app)
        self.assertEqual(type(storage_limiter._storage), MemcachedStorage)
59
60
    def test_error_message(self):
        """The breached limit appears in the default 429 body and is exposed
        as ``e.description`` to a custom 429 error handler."""
        config = {C.GLOBAL_LIMITS: "1 per day"}

        # Default 429 response.
        app, _ = self.build_app(config)

        @app.route("/")
        def null():
            return ""

        with app.test_client() as cli:
            cli.get("/")
            self.assertTrue("1 per 1 day" in cli.get("/").data.decode())

        # Custom 429 handler, registered on a fresh app *before* any request:
        # newer Flask versions raise AssertionError when setup methods such as
        # ``errorhandler`` are called after the app has handled a request.
        app, _ = self.build_app(config)

        @app.route("/")
        def null_custom():
            return ""

        @app.errorhandler(429)
        def ratelimit_handler(e):
            return make_response('{"error" : "rate limit %s"}' % str(e.description), 429)

        with app.test_client() as cli:
            cli.get("/")
            self.assertEqual({'error': 'rate limit 1 per 1 day'}, json.loads(cli.get("/").data.decode()))
78
79
    def test_reset(self):
        """``Limiter.reset`` forgets recorded hits so the route is usable again."""
        app, limiter = self.build_app({C.GLOBAL_LIMITS: "1 per day"})

        @app.route("/")
        def null():
            return "Hello Reset"

        with app.test_client() as client:
            client.get("/")
            self.assertIn("1 per 1 day", client.get("/").data.decode())

            limiter.reset()

            # One request fits again, then the limit is hit as before.
            body = client.get("/").data.decode()
            self.assertEqual("Hello Reset", body)
            self.assertIn("1 per 1 day", client.get("/").data.decode())
94
95
    def test_swallow_error(self):
        """With SWALLOW_ERRORS enabled, a failing ``hit`` does not break the request."""
        app, limiter = self.build_app({
            C.GLOBAL_LIMITS: "1 per day",
            C.SWALLOW_ERRORS: True
        })

        @app.route("/")
        def null():
            return "ok"

        def explode(*args, **kwargs):
            raise Exception

        with app.test_client() as client:
            with mock.patch("limits.strategies.FixedWindowRateLimiter.hit") as hit:
                hit.side_effect = explode
                self.assertIn("ok", client.get("/").data.decode())
112
113
    def test_no_swallow_error(self):
        """Without SWALLOW_ERRORS, an exception raised by the rate limiter
        surfaces as a 500 handled by the app's 500 error handler."""
        app, limiter = self.build_app({
            C.GLOBAL_LIMITS: "1 per day",
        })

        @app.route("/")
        def null():
            return "ok"

        @app.errorhandler(500)
        def e500(e):
            # Flask >= 1.1 passes a 500 handler an ``InternalServerError`` that
            # wraps the unhandled exception in ``original_exception``; older
            # versions pass the exception itself.
            original = getattr(e, "original_exception", None)
            return str(original if original is not None else e), 500

        with app.test_client() as cli:
            with mock.patch("limits.strategies.FixedWindowRateLimiter.hit") as hit:
                def raiser(*a, **k):
                    raise Exception("underlying")

                hit.side_effect = raiser
                self.assertEqual(500, cli.get("/").status_code)
                self.assertEqual("underlying", cli.get("/").data.decode())
134
135
    def test_combined_rate_limits(self):
        """``/t1`` has its own decorated limits; ``/t2`` falls under the global
        "1 per hour; 10 per day" limits and is rejected on its second request."""
        app, limiter = self.build_app({
            C.GLOBAL_LIMITS: "1 per hour; 10 per day"
        })

        @app.route("/t1")
        @limiter.limit("100 per hour;10/minute")
        def t1():
            return "t1"

        @app.route("/t2")
        def t2():
            return "t2"

        expected = (("/t1", 200), ("/t2", 200), ("/t2", 429))
        with hiro.Timeline().freeze():
            with app.test_client() as client:
                for path, status in expected:
                    self.assertEqual(status, client.get(path).status_code)
154
155
    def test_key_func(self):
        """A per-decorator key function overrides the limiter's default one."""
        app, limiter = self.build_app()

        @app.route("/t1")
        @limiter.limit("100 per minute", lambda: "test")
        def t1():
            return "test"

        forwarded = {"X_FORWARDED_FOR": "127.0.0.2"}
        with hiro.Timeline().freeze():
            with app.test_client() as client:
                for _ in range(100):
                    response = client.get("/t1", headers=forwarded)
                    self.assertEqual(200, response.status_code)
                # The key is the constant "test", so a request without the
                # forwarded header still counts against the same bucket.
                self.assertEqual(429, client.get("/t1").status_code)
170
171
    def test_multiple_decorators(self):
        """Stacked limits with different key functions are enforced together."""
        app, limiter = self.build_app(key_func=get_ipaddr)

        @app.route("/t1")
        @limiter.limit("100 per minute", lambda: "test")  # one bucket shared by every client
        @limiter.limit("50/minute")  # keyed per client by the default key_func
        def t1():
            return "test"

        second_ip = {"X_FORWARDED_FOR": "127.0.0.2"}
        third_ip = {"X_FORWARDED_FOR": "127.0.0.3"}

        with hiro.Timeline().freeze():
            with app.test_client() as client:
                # 127.0.0.2 is cut off after its own 50 requests.
                for attempt in range(100):
                    wanted = 200 if attempt < 50 else 429
                    self.assertEqual(wanted, client.get("/t1", headers=second_ip).status_code)
                # Another 50 accepted requests from the default address...
                for _ in range(50):
                    self.assertEqual(200, client.get("/t1").status_code)
                # ...fill the shared bucket, so every address is now rejected.
                self.assertEqual(429, client.get("/t1").status_code)
                self.assertEqual(429, client.get("/t1", headers=third_ip).status_code)
190
191
    def test_logging(self):
        """Breaching a limit sends exactly one record to a handler on the limiter's logger."""
        app = Flask(__name__)
        limiter = Limiter(app, key_func=get_remote_address)
        handler = mock.Mock()
        handler.level = logging.INFO
        limiter.logger.addHandler(handler)

        @app.route("/t1")
        @limiter.limit("1/minute")
        def t1():
            return "test"

        with app.test_client() as client:
            for status in (200, 429):
                self.assertEqual(status, client.get("/t1").status_code)
        self.assertEqual(handler.handle.call_count, 1)
207
208
    def test_reuse_logging(self):
        """Handlers copied from ``app.logger`` onto the limiter's logger receive
        exactly one record when the limit is breached."""
        app = Flask(__name__)
        app_handler = mock.Mock()
        app_handler.level = logging.INFO
        app.logger.addHandler(app_handler)
        limiter = Limiter(app, key_func=get_remote_address)
        for existing in app.logger.handlers:
            limiter.logger.addHandler(existing)

        @app.route("/t1")
        @limiter.limit("1/minute")
        def t1():
            return "42"

        with app.test_client() as client:
            for _ in range(2):
                client.get("/t1")

        self.assertEqual(app_handler.handle.call_count, 1)
227
228
    def test_exempt_routes(self):
        """``@limiter.exempt`` lifts the default limits from a route."""
        app, limiter = self.build_app(default_limits=["1/minute"])

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.exempt
        def t2():
            return "test"

        expectations = (("/t1", 200), ("/t1", 429), ("/t2", 200), ("/t2", 200))
        with app.test_client() as client:
            for path, status in expectations:
                self.assertEqual(client.get(path).status_code, status)
245
246
    def test_blueprint(self):
        """Blueprint views get the default limits unless decorated with their own."""
        app, limiter = self.build_app(default_limits=["1/minute"])
        bp = Blueprint("main", __name__)

        @bp.route("/t1")
        def t1():
            return "test"

        @bp.route("/t2")
        @limiter.limit("10 per minute")
        def t2():
            return "test"

        app.register_blueprint(bp)

        plan = [("/t1", 200), ("/t1", 429)] + [("/t2", 200)] * 10 + [("/t2", 429)]
        with app.test_client() as client:
            for path, status in plan:
                self.assertEqual(client.get(path).status_code, status)
267
268
    def test_register_blueprint(self):
        """Limits and exemptions applied to whole blueprints.

        bp1 gets a static "1/second" limit, bp2 keeps the app default
        ("1/minute"), bp3 is exempt, and bp4 gets a limit from a callable.
        """
        app, limiter = self.build_app(default_limits=["1/minute"])
        bp_1 = Blueprint("bp1", __name__)
        bp_2 = Blueprint("bp2", __name__)
        bp_3 = Blueprint("bp3", __name__)
        bp_4 = Blueprint("bp4", __name__)

        @bp_1.route("/t1")
        def t1():
            return "test"

        @bp_1.route("/t2")
        def t2():
            return "test"

        @bp_2.route("/t3")
        def t3():
            return "test"

        @bp_3.route("/t4")
        def t4():
            return "test"

        # Distinct name so it does not shadow the bp_3 view ``t4``.
        @bp_4.route("/t5")
        def t5():
            return "test"

        def dy_limit():
            return "1/second"

        app.register_blueprint(bp_1)
        app.register_blueprint(bp_2)
        app.register_blueprint(bp_3)
        app.register_blueprint(bp_4)

        limiter.limit("1/second")(bp_1)
        limiter.exempt(bp_3)
        limiter.limit(dy_limit)(bp_4)

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                # bp1: 1/second per view
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/t2").status_code, 200)

                # bp2: app default 1/minute, still blocked 10 seconds later
                self.assertEqual(cli.get("/t3").status_code, 200)
                for _ in range(10):
                    timeline.forward(1)
                    self.assertEqual(cli.get("/t3").status_code, 429)

                # bp3: exempt
                for _ in range(10):
                    self.assertEqual(cli.get("/t4").status_code, 200)

                # bp4: callable limit ("1/second")
                self.assertEqual(cli.get("/t5").status_code, 200)
                self.assertEqual(cli.get("/t5").status_code, 429)
328
329
    def test_disabled_flag(self):
        """With ENABLED set to False, neither default nor decorated limits apply."""
        app, limiter = self.build_app(
            config={C.ENABLED: False},
            default_limits=["1/minute"]
        )

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.limit("10 per minute")
        def t2():
            return "test"

        with app.test_client() as client:
            for _ in range(2):
                self.assertEqual(client.get("/t1").status_code, 200)
            for _ in range(11):
                self.assertEqual(client.get("/t2").status_code, 200)
350
351
    def test_fallback_to_memory_config(self):
        """The in-memory fallback can come from the constructor or from config."""
        def build(config, **extra):
            _, built = self.build_app(
                config=config,
                default_limits=["5/minute"],
                storage_uri="redis://localhost:6379",
                **extra
            )
            return built

        from_constructor = build({C.ENABLED: True}, in_memory_fallback=["1/minute"])
        self.assertEqual(len(from_constructor._in_memory_fallback), 1)

        from_config = build({C.ENABLED: True, C.IN_MEMORY_FALLBACK: "1/minute"})
        self.assertEqual(len(from_config._in_memory_fallback), 1)
366
367
    def test_fallback_to_memory_backoff_check(self):
        """While Redis fails the in-memory fallback is used; after Redis
        recovers, the switch back happens only on a later check."""
        app, limiter = self.build_app(
            config={C.ENABLED: True},
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
            in_memory_fallback=["1/minute"]
        )

        @app.route("/t1")
        def t1():
            return "test"

        with app.test_client() as client:
            def raiser(*a):
                raise Exception("redis dead")

            with hiro.Timeline() as timeline:
                with mock.patch("redis.client.Redis.execute_command") as exec_command:
                    exec_command.side_effect = raiser
                    self.assertEqual(client.get("/t1").status_code, 200)
                    self.assertEqual(client.get("/t1").status_code, 429)
                    # 1 + 2 + 4 + 8 + 16 = 31 s: still inside the 1/minute fallback
                    for step in (1, 2, 4, 8, 16):
                        timeline.forward(step)
                        self.assertEqual(client.get("/t1").status_code, 429)
                    timeline.forward(32)
                    self.assertEqual(client.get("/t1").status_code, 200)
                # Redis works again, but exponential backoff means it is only
                # marked healthy after pow(2, 0) seconds and the next check.
                self.assertEqual(client.get("/t1").status_code, 429)
                timeline.forward(1)
                for _ in range(5):
                    self.assertEqual(client.get("/t1").status_code, 200)
                self.assertEqual(client.get("/t1").status_code, 429)
413
414
    def test_fallback_to_memory(self):
        """Regular limits apply while Redis is healthy, the in-memory fallback
        while it fails, and regular limits again once it recovers."""
        app, limiter = self.build_app(
            config={C.ENABLED: True},
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
            in_memory_fallback=["1/minute"]
        )

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.limit("3 per minute")
        def t2():
            return "test"

        with app.test_client() as client:
            def expect(path, *statuses):
                # Issue one request per expected status, in order.
                for status in statuses:
                    self.assertEqual(client.get(path).status_code, status)

            # Redis healthy: default 5/minute on /t1, decorated 3/minute on /t2.
            expect("/t1", 200, 200, 200, 200, 200, 429)
            expect("/t2", 200, 200, 200, 429)

            def raiser(*a):
                raise Exception("redis dead")

            # Redis failing: both routes are held to the 1/minute fallback.
            with mock.patch("redis.client.Redis.execute_command") as exec_command:
                exec_command.side_effect = raiser
                expect("/t1", 200, 429)
                expect("/t2", 200, 429)

            # redis back to normal, go back to regular limits
            with hiro.Timeline() as timeline:
                timeline.forward(1)
                limiter._storage.storage.flushall()
                expect("/t2", 200, 200, 200, 429)
462
463
    def test_decorated_dynamic_limits(self):
        """Callable limits are resolved per request (per client address or app config)."""
        app, limiter = self.build_app({"X": "2 per second"}, default_limits=["1/second"])

        per_address = {
            "127.0.0.1": "10 per minute",
            "127.0.0.2": "1 per minute"
        }

        def request_context_limit():
            route = request.access_route
            remote_addr = (route and route[0]) or request.remote_addr or '127.0.0.1'
            return per_address.get(remote_addr, '1 per minute')

        @app.route("/t1")
        @limiter.limit("20/day")
        @limiter.limit(lambda: current_app.config.get("X"))
        @limiter.limit(request_context_limit)
        def t1():
            return "42"

        @app.route("/t2")
        @limiter.limit(lambda: current_app.config.get("X"))
        def t2():
            return "42"

        first_client = {"X_FORWARDED_FOR": "127.0.0.1, 127.0.0.0"}
        second_client = {"X_FORWARDED_FOR": "127.0.0.2"}

        with app.test_client() as client:
            with hiro.Timeline().freeze() as timeline:
                for _ in range(10):
                    self.assertEqual(client.get("/t1", headers=first_client).status_code, 200)
                    timeline.forward(1)
                self.assertEqual(client.get("/t1", headers=first_client).status_code, 429)
                self.assertEqual(client.get("/t1", headers=second_client).status_code, 200)
                self.assertEqual(client.get("/t1", headers=second_client).status_code, 429)
                timeline.forward(60)
                self.assertEqual(client.get("/t1", headers=second_client).status_code, 200)
                for status in (200, 200, 429):
                    self.assertEqual(client.get("/t2").status_code, status)
                timeline.forward(1)
                self.assertEqual(client.get("/t2").status_code, 200)
505
506
    def test_invalid_decorated_dynamic_limits(self):
        """An unparsable limit returned by a callable is logged on each load
        attempt, and the route is still limited."""
        app = Flask(__name__)
        app.config.setdefault("X", "2 per sec")
        limiter = Limiter(app, default_limits=["1/second"], key_func=get_remote_address)
        handler = mock.Mock()
        handler.level = logging.INFO
        limiter.logger.addHandler(handler)

        @app.route("/t1")
        @limiter.limit(lambda: current_app.config.get("X"))
        def t1():
            return "42"

        with app.test_client() as client:
            with hiro.Timeline().freeze():
                for status in (200, 429):
                    self.assertEqual(client.get("/t1").status_code, status)

        # 2 for invalid limit, 1 for warning.
        self.assertEqual(handler.handle.call_count, 3)
        messages = [call[0][0].msg for call in handler.handle.call_args_list]
        self.assertIn("failed to load ratelimit", messages[0])
        self.assertIn("failed to load ratelimit", messages[1])
        self.assertIn("exceeded at endpoint", messages[2])
528
529
    def test_invalid_decorated_static_limits(self):
        """An unparsable static decorator limit is logged as a configuration
        failure; the route is still limited (presumably by the default)."""
        app = Flask(__name__)
        limiter = Limiter(app, default_limits=["1/second"], key_func=get_remote_address)
        handler = mock.Mock()
        handler.level = logging.INFO
        limiter.logger.addHandler(handler)

        @app.route("/t1")
        @limiter.limit("2/sec")
        def t1():
            return "42"

        with app.test_client() as client:
            with hiro.Timeline().freeze():
                for status in (200, 429):
                    self.assertEqual(client.get("/t1").status_code, status)

        first, second = handler.handle.call_args_list[:2]
        self.assertIn("failed to configure", first[0][0].msg)
        self.assertIn("exceeded at endpoint", second[0][0].msg)
547
548
    def test_invalid_decorated_static_limit_blueprint(self):
        """Same as the static-limit case, but with the bad limit on a blueprint."""
        app = Flask(__name__)
        limiter = Limiter(app, default_limits=["1/second"], key_func=get_remote_address)
        handler = mock.Mock()
        handler.level = logging.INFO
        limiter.logger.addHandler(handler)
        blueprint = Blueprint("bp1", __name__)

        @blueprint.route("/t1")
        def t1():
            return "42"

        limiter.limit("2/sec")(blueprint)
        app.register_blueprint(blueprint)

        with app.test_client() as client:
            with hiro.Timeline().freeze():
                for status in (200, 429):
                    self.assertEqual(client.get("/t1").status_code, status)

        first, second = handler.handle.call_args_list[:2]
        self.assertIn("failed to configure", first[0][0].msg)
        self.assertIn("exceeded at endpoint", second[0][0].msg)
569
570
    def test_invalid_decorated_dynamic_limits_blueprint(self):
        """Same as the dynamic-limit case, but with the callable limit on a blueprint."""
        app = Flask(__name__)
        app.config.setdefault("X", "2 per sec")
        limiter = Limiter(app, default_limits=["1/second"], key_func=get_remote_address)
        handler = mock.Mock()
        handler.level = logging.INFO
        limiter.logger.addHandler(handler)
        blueprint = Blueprint("bp1", __name__)

        @blueprint.route("/t1")
        def t1():
            return "42"

        limiter.limit(lambda: current_app.config.get("X"))(blueprint)
        app.register_blueprint(blueprint)

        with app.test_client() as client:
            with hiro.Timeline().freeze():
                for status in (200, 429):
                    self.assertEqual(client.get("/t1").status_code, status)

        self.assertEqual(handler.handle.call_count, 3)
        messages = [call[0][0].msg for call in handler.handle.call_args_list]
        self.assertIn("failed to load ratelimit", messages[0])
        self.assertIn("failed to load ratelimit", messages[1])
        self.assertIn("exceeded at endpoint", messages[2])
594
595
    def test_multiple_apps(self):
        """One ``Limiter`` initialised on two apps keeps each app's limits separate."""
        app1 = Flask(__name__)
        app2 = Flask(__name__)

        limiter = Limiter(default_limits=["1/second"], key_func=get_remote_address)
        for application in (app1, app2):
            limiter.init_app(application)

        @app1.route("/ping")
        def ping():
            return "PONG"

        @app1.route("/slowping")
        @limiter.limit("1/minute")
        def slow_ping():
            return "PONG"

        @app2.route("/ping")
        @limiter.limit("2/second")
        def ping_2():
            return "PONG"

        @app2.route("/slowping")
        @limiter.limit("2/minute")
        def slow_ping_2():
            return "PONG"

        def status(client, path):
            return client.get(path).status_code

        with hiro.Timeline().freeze() as timeline:
            with app1.test_client() as client:
                self.assertEqual(status(client, "/ping"), 200)
                self.assertEqual(status(client, "/ping"), 429)
                timeline.forward(1)
                self.assertEqual(status(client, "/ping"), 200)
                self.assertEqual(status(client, "/slowping"), 200)
                timeline.forward(59)
                self.assertEqual(status(client, "/slowping"), 429)
                timeline.forward(1)
                self.assertEqual(status(client, "/slowping"), 200)
            with app2.test_client() as client:
                self.assertEqual(status(client, "/ping"), 200)
                self.assertEqual(status(client, "/ping"), 200)
                self.assertEqual(status(client, "/ping"), 429)
                timeline.forward(1)
                self.assertEqual(status(client, "/ping"), 200)
                self.assertEqual(status(client, "/slowping"), 200)
                timeline.forward(59)
                self.assertEqual(status(client, "/slowping"), 200)
                self.assertEqual(status(client, "/slowping"), 429)
                timeline.forward(1)
                self.assertEqual(status(client, "/slowping"), 200)
645
646
    def test_headers_no_breach(self):
        """With headers enabled, successful responses carry rate limit headers."""
        app = Flask(__name__)
        limiter = Limiter(
            app, default_limits=["10/minute"], headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.limit("2/second; 5 per minute; 10/hour")
        def t2():
            return "test"

        header_names = ('X-RateLimit-Limit', 'X-RateLimit-Remaining',
                        'X-RateLimit-Reset', 'Retry-After')

        with hiro.Timeline().freeze():
            with app.test_client() as client:
                response = client.get("/t1")
                self.assertEqual(
                    tuple(response.headers.get(name) for name in header_names),
                    ('10', '9', str(int(time.time() + 61)), str(60))
                )
                response = client.get("/t2")
                self.assertEqual(
                    tuple(response.headers.get(name) for name in header_names),
                    ('2', '1', str(int(time.time() + 2)), str(1))
                )
699
700
    def test_headers_breach(self):
        """After the 10-per-minute limit is breached, headers report it exhausted."""
        app = Flask(__name__)
        limiter = Limiter(
            app, default_limits=["10/minute"],
            headers_enabled=True, key_func=get_remote_address
        )

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        header_names = ('X-RateLimit-Limit', 'X-RateLimit-Remaining',
                        'X-RateLimit-Reset', 'Retry-After')

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as client:
                for _ in range(11):
                    response = client.get("/t1")
                    timeline.forward(1)

                self.assertEqual(
                    tuple(response.headers.get(name) for name in header_names),
                    ('10', '0', str(int(time.time() + 50)), '50')
                )
734
735
    def test_retry_after(self):
        """A client that waits the advertised Retry-After is admitted again."""
        app = Flask(__name__)
        # Kept only for its side effect of installing the limiter on ``app``.
        _limiter = Limiter(
            app,
            default_limits=["1/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/t1")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                first = cli.get("/t1")
                wait_seconds = int(first.headers.get('Retry-After'))
                self.assertGreater(wait_seconds, 0)

                # Jump exactly the advertised delay and try again.
                timeline.forward(wait_seconds)
                self.assertEqual(cli.get("/t1").status_code, 200)
754
755
    def test_custom_headers_from_setter(self):
        """Header names overridden on the limiter instance appear in responses.

        The private ``_header_mapping`` is patched directly, and Retry-After is
        configured (``retry_after='http-date'``) to be rendered as a date.
        The clock is frozen via ``freeze(0)``. This presumably pins it at the
        Unix epoch, which is why the expected Retry-After date is in
        January 1970.
        """
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address,
            retry_after='http-date'
        )
        renamed_headers = (
            (HEADERS.RESET, 'X-Reset'),
            (HEADERS.LIMIT, 'X-Limit'),
            (HEADERS.REMAINING, 'X-Remaining'),
        )
        for header, custom_name in renamed_headers:
            limiter._header_mapping[header] = custom_name

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze(0) as timeline:
            with app.test_client() as cli:
                for _ in range(11):
                    resp = cli.get("/t1")
                    timeline.forward(1)

                headers = resp.headers
                self.assertEqual(headers.get('X-Limit'), '10')
                self.assertEqual(headers.get('X-Remaining'), '0')
                self.assertEqual(
                    headers.get('X-Reset'), str(int(time.time() + 50))
                )
                self.assertEqual(
                    headers.get('Retry-After'), 'Thu, 01 Jan 1970 00:01:01 GMT'
                )
792
793
    def test_custom_headers_from_config(self):
        """Header names supplied through the Flask config appear in responses."""
        app = Flask(__name__)
        header_config = (
            (C.HEADER_LIMIT, "X-Limit"),
            (C.HEADER_REMAINING, "X-Remaining"),
            (C.HEADER_RESET, "X-Reset"),
        )
        for config_key, header_name in header_config:
            app.config.setdefault(config_key, header_name)

        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                for _ in range(11):
                    resp = cli.get("/t1")
                    timeline.forward(1)

                expected_headers = (
                    ('X-Limit', '10'),
                    ('X-Remaining', '0'),
                    ('X-Reset', str(int(time.time() + 50))),
                )
                for header, value in expected_headers:
                    self.assertEqual(resp.headers.get(header), value)
827
828
    def test_named_shared_limit(self):
        """Routes sharing a named scope share one counter; other scopes don't."""
        app, limiter = self.build_app()
        scope_a_limit = limiter.shared_limit("1/minute", scope='a')
        scope_b_limit = limiter.shared_limit("1/minute", scope='b')

        @app.route("/t1")
        @scope_a_limit
        def route1():
            return "route1"

        @app.route("/t2")
        @scope_a_limit
        def route2():
            return "route2"

        @app.route("/t3")
        @scope_b_limit
        def route3():
            return "route3"

        # /t1 uses up scope 'a'. /t3 draws on the separate scope 'b'.
        # /t2 is then rejected because it shares the exhausted scope 'a'.
        expectations = (("/t1", 200), ("/t3", 200), ("/t2", 429))
        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                for path, status in expectations:
                    self.assertEqual(status, cli.get(path).status_code)
853
854
    def test_application_shared_limit(self):
855 View Code Duplication
        app, limiter = self.build_app(application_limits=["2/minute"])
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
856
        @app.route("/t1")
857
        def t1():
858
            return "route1"
859
860
        @app.route("/t2")
861
        def t2():
862
            return "route2"
863
864
        with hiro.Timeline().freeze():
865
            with app.test_client() as cli:
866
                self.assertEqual(200, cli.get("/t1").status_code)
867
                self.assertEqual(200, cli.get("/t2").status_code)
868
                self.assertEqual(429, cli.get("/t1").status_code)
869
870
    def test_dynamic_shared_limit(self):
871 View Code Duplication
        app, limiter = self.build_app()
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
872
        fn_a = mock.Mock()
873
        fn_b = mock.Mock()
874
        fn_a.return_value = "foo"
875
        fn_b.return_value = "bar"
876
877
        dy_limit_a = limiter.shared_limit("1/minute", scope=fn_a)
878
        dy_limit_b = limiter.shared_limit("1/minute", scope=fn_b)
879
880
        @app.route("/t1")
881
        @dy_limit_a
882
        def t1():
883
            return "route1"
884
885
        @app.route("/t2")
886
        @dy_limit_a
887
        def t2():
888
            return "route2"
889
890
        @app.route("/t3")
891
        @dy_limit_b
892
        def t3():
893
            return "route3"
894
895
        with hiro.Timeline().freeze():
896
            with app.test_client() as cli:
897
                self.assertEqual(200, cli.get("/t1").status_code)
898
                self.assertEqual(200, cli.get("/t3").status_code)
899
                self.assertEqual(429, cli.get("/t2").status_code)
900
                self.assertEqual(429, cli.get("/t3").status_code)
901
                self.assertEqual(2, fn_a.call_count)
902
                self.assertEqual(2, fn_b.call_count)
903
                fn_b.assert_called_with("t3")
904
                fn_a.assert_has_calls([mock.call("t1"), mock.call("t2")])
905
906
    def test_conditional_limits(self):
        """``exempt_when`` callables decide, per request, whether a limit applies."""
        app = Flask(__name__)
        limiter = Limiter(app, key_func=get_remote_address)

        @app.route("/limited")
        @limiter.limit("1 per day")
        def limited_route():
            return "passed"

        @app.route("/unlimited")
        @limiter.limit("1 per day", exempt_when=lambda: True)
        def never_limited_route():
            return "should always pass"

        # The lambda below reads this name through its closure at request
        # time, so rebinding it later toggles the exemption.
        is_exempt = False

        @app.route("/conditional")
        @limiter.limit("1 per day", exempt_when=lambda: is_exempt)
        def conditionally_limited_route():
            return "conditional"

        with app.test_client() as cli:
            def expect(path, status):
                """GET ``path`` and assert the response status."""
                self.assertEqual(cli.get(path).status_code, status)

            expect("/limited", 200)
            expect("/limited", 429)

            expect("/unlimited", 200)
            expect("/unlimited", 200)

            expect("/conditional", 200)
            expect("/conditional", 429)
            is_exempt = True
            expect("/conditional", 200)
            is_exempt = False
            expect("/conditional", 429)
941
942
    def test_conditional_shared_limits(self):
        """``exempt_when`` also works for limits shared under one scope."""
        app = Flask(__name__)
        limiter = Limiter(app, key_func=get_remote_address)

        @app.route("/limited")
        @limiter.shared_limit("1 per day", "test_scope")
        def limited_route():
            return "passed"

        @app.route("/unlimited")
        @limiter.shared_limit("1 per day", "test_scope", exempt_when=lambda: True)
        def never_limited_route():
            return "should always pass"

        # Read through the closure at request time; rebinding toggles the
        # exemption for /conditional.
        is_exempt = False

        @app.route("/conditional")
        @limiter.shared_limit("1 per day", "test_scope", exempt_when=lambda: is_exempt)
        def conditionally_limited_route():
            return "conditional"

        with app.test_client() as cli:
            def expect(path, status):
                """GET ``path`` and assert the response status."""
                self.assertEqual(cli.get(path).status_code, status)

            # Exempt requests don't consume the shared "test_scope" budget.
            expect("/unlimited", 200)
            expect("/unlimited", 200)

            expect("/limited", 200)
            expect("/limited", 429)

            # The shared budget is already spent by /limited.
            expect("/conditional", 429)
            is_exempt = True
            expect("/conditional", 200)
            is_exempt = False
            expect("/conditional", 429)
976
977
    def test_whitelisting(self):
        """Requests matched by a ``request_filter`` bypass the default limit."""
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["1/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/")
        def t():
            return "test"

        @limiter.request_filter
        def w():
            # Only requests carrying the header ``internal: true`` are filtered.
            return request.headers.get("internal", None) == "true"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                self.assertEqual(cli.get("/").status_code, 200)
                self.assertEqual(cli.get("/").status_code, 429)
                # After a minute the 1/minute default admits a request again.
                timeline.forward(60)
                self.assertEqual(cli.get("/").status_code, 200)

                # Filtered requests are never limited.
                for _ in range(10):
                    status = cli.get("/", headers={"internal": "true"}).status_code
                    self.assertEqual(status, 200)
1007
1008
    def test_pluggable_views(self):
        """Limits listed in a ``View.decorators`` apply to class-based views."""
        app, limiter = self.build_app(default_limits=["1/hour"])

        class Va(View):
            methods = ['GET', 'POST']
            decorators = [limiter.limit("2/second")]

            def dispatch_request(self):
                return request.method.lower()

        class Vb(View):
            methods = ['GET']
            decorators = [limiter.limit("1/second, 3/minute")]

            def dispatch_request(self):
                return request.method.lower()

        class Vc(View):
            methods = ['GET']

            def dispatch_request(self):
                return request.method.lower()

        registrations = (("/a", Va, "a"), ("/b", Vb, "b"), ("/c", Vc, "c"))
        for rule, view_cls, endpoint in registrations:
            app.add_url_rule(rule, view_func=view_cls.as_view(endpoint))

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                # /a: GET and POST draw on the same 2/second budget.
                self.assertEqual(200, cli.get("/a").status_code)
                self.assertEqual(200, cli.get("/a").status_code)
                self.assertEqual(429, cli.post("/a").status_code)

                # /b: spacing requests one second apart satisfies 1/second
                # until the 3/minute limit is reached.
                for _ in range(3):
                    self.assertEqual(200, cli.get("/b").status_code)
                    timeline.forward(1)
                self.assertEqual(429, cli.get("/b").status_code)

                # /c has no decorators, so the 1/hour default applies.
                self.assertEqual(200, cli.get("/c").status_code)
                self.assertEqual(429, cli.get("/c").status_code)
1050
1051
    def test_pluggable_method_views(self):
        """Limits in ``MethodView.decorators`` apply, optionally per method."""
        app, limiter = self.build_app(default_limits=["1/hour"])

        class Va(MethodView):
            decorators = [limiter.limit("2/second")]

            def get(self):
                return request.method.lower()

            def post(self):
                return request.method.lower()

        class Vb(MethodView):
            decorators = [limiter.limit("1/second, 3/minute")]

            def get(self):
                return request.method.lower()

        class Vc(MethodView):
            def get(self):
                return request.method.lower()

        class Vd(MethodView):
            decorators = [limiter.limit("1/minute", methods=['get'])]

            def get(self):
                return request.method.lower()

            def post(self):
                return request.method.lower()

        registrations = (
            ("/a", Va, "a"),
            ("/b", Vb, "b"),
            ("/c", Vc, "c"),
            ("/d", Vd, "d"),
        )
        for rule, view_cls, endpoint in registrations:
            app.add_url_rule(rule, view_func=view_cls.as_view(endpoint))

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                # /a: GET and POST share the 2/second budget.
                self.assertEqual(200, cli.get("/a").status_code)
                self.assertEqual(200, cli.get("/a").status_code)
                self.assertEqual(429, cli.get("/a").status_code)
                self.assertEqual(429, cli.post("/a").status_code)

                # /b: one request per second until 3/minute is hit.
                for _ in range(3):
                    self.assertEqual(200, cli.get("/b").status_code)
                    timeline.forward(1)
                self.assertEqual(429, cli.get("/b").status_code)

                # /c: undecorated, governed by the 1/hour default.
                self.assertEqual(200, cli.get("/c").status_code)
                self.assertEqual(429, cli.get("/c").status_code)

                # /d: the limit is restricted to GET; POSTs are unaffected.
                self.assertEqual(200, cli.get("/d").status_code)
                self.assertEqual(429, cli.get("/d").status_code)
                self.assertEqual(200, cli.post("/d").status_code)
                self.assertEqual(200, cli.post("/d").status_code)
1108
1109
    def test_flask_restful_resource(self):
        """Limits in a flask-restful ``Resource.decorators`` are enforced.

        Resources are registered through ``flask_restful.Api``. Va and Vb carry
        their own limits; Vc relies on the app-wide "1/hour" default.
        """
        # Only ``Resource`` is imported at module level. ``Api`` must be
        # brought into scope here; the previous ``restful.Api`` referenced
        # an undefined name.
        from flask_restful import Api

        app, limiter = self.build_app(
                default_limits=["1/hour"]
        )
        api = Api(app)

        class Va(Resource):
            decorators = [limiter.limit("2/second")]

            def get(self):
                return request.method.lower()

            def post(self):
                return request.method.lower()

        class Vb(Resource):
            decorators = [limiter.limit("1/second, 3/minute")]

            def get(self):
                return request.method.lower()

        class Vc(Resource):
            def get(self):
                return request.method.lower()

        api.add_resource(Va, "/a")
        api.add_resource(Vb, "/b")
        api.add_resource(Vc, "/c")

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                # /a: GET and POST share the 2/second budget.
                self.assertEqual(200, cli.get("/a").status_code)
                self.assertEqual(200, cli.get("/a").status_code)
                self.assertEqual(429, cli.get("/a").status_code)
                self.assertEqual(429, cli.post("/a").status_code)
                # /b: one request per second until the 3/minute limit.
                self.assertEqual(200, cli.get("/b").status_code)
                timeline.forward(1)
                self.assertEqual(200, cli.get("/b").status_code)
                timeline.forward(1)
                self.assertEqual(200, cli.get("/b").status_code)
                timeline.forward(1)
                self.assertEqual(429, cli.get("/b").status_code)
                # /c: only the 1/hour default applies.
                self.assertEqual(200, cli.get("/c").status_code)
                self.assertEqual(429, cli.get("/c").status_code)
1153
1154
    def test_separate_method_limits(self):
        """With ``per_method=True``, GET and POST are counted separately."""
        app, limiter = self.build_app()

        @limiter.limit("1/second", per_method=True)
        @app.route("/", methods=["GET", "POST"])
        def root():
            return "root"

        # Freeze the clock, as the sibling tests do. Each pair of requests is
        # then guaranteed to land inside the same one-second window instead
        # of depending on real elapsed time.
        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/").status_code)
                self.assertEqual(429, cli.get("/").status_code)
                self.assertEqual(200, cli.post("/").status_code)
                self.assertEqual(429, cli.post("/").status_code)
1168
1169
    def test_explicit_method_limits(self):
        """A limit restricted to ``methods=["GET"]`` leaves POST unlimited."""
        app, limiter = self.build_app()

        @limiter.limit("1/second", methods=["GET"])
        @app.route("/", methods=["GET", "POST"])
        def root():
            return "root"

        # Freeze the clock so the two GETs are guaranteed to share one
        # one-second window regardless of how slowly the test runs.
        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/").status_code)
                self.assertEqual(429, cli.get("/").status_code)
                self.assertEqual(200, cli.post("/").status_code)
                self.assertEqual(200, cli.post("/").status_code)
1183
1184
    def test_no_auto_check(self):
        """With ``auto_check=False`` limits only apply after an explicit check()."""
        app, limiter = self.build_app(auto_check=False)

        @limiter.limit("1/second", per_method=True)
        @app.route("/", methods=["GET", "POST"])
        def root():
            return "root"

        def two_gets():
            """GET / twice on a freshly frozen clock; return both statuses."""
            with hiro.Timeline().freeze():
                with app.test_client() as cli:
                    first = cli.get("/").status_code
                    second = cli.get("/").status_code
                    return first, second

        # Nothing triggers the check yet, so both requests succeed.
        self.assertEqual((200, 200), two_gets())

        # attach before_request to perform check
        @app.before_request
        def _():
            limiter.check()

        self.assertEqual((200, 429), two_gets())
1206
1207
    def test_custom_error_message(self):
        """A limit's ``error_message`` (string or callable) becomes the 429
        description handed to the application's error handler."""
        app, limiter = self.build_app()

        @app.errorhandler(429)
        def ratelimit_handler(e):
            # Echo the description as the body so the test can read it back.
            return make_response(e.description, 429)

        def l1():
            return "1/second"

        def e1():
            return "dos"

        @limiter.limit("1/second", error_message="uno")
        @app.route("/t1")
        def t1():
            return "1"

        @limiter.limit(l1, error_message=e1)
        @app.route("/t2")
        def t2():
            return "2"

        s1 = limiter.shared_limit("1/second", scope='error_message', error_message="tres")

        @app.route("/t3")
        @s1
        def t3():
            return "3"

        expected_messages = (("/t1", b'uno'), ("/t2", b'dos'), ("/t3", b'tres'))
        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                for path, message in expected_messages:
                    cli.get(path)  # spend the single allowed request
                    resp = cli.get(path)
                    self.assertEqual(429, resp.status_code)
                    self.assertEqual(resp.data, message)
1251
1252
    def test_custom_key_prefix(self):
        """Limiters with different key prefixes keep independent counters.

        The three apps differ in how a key prefix is supplied:
        - app1: the ``key_prefix`` constructor argument;
        - app2: the ``C.KEY_PREFIX`` config entry;
        - app3: no prefix at all.
        All three use the same redis URI. Each registers a view with the same
        name, ``t1``, and the same "1/day" limit. Presumably their storage
        keys then differ only by prefix, so each app must get its own fresh
        allowance.
        """
        storage_uri = "redis://localhost:6379"
        app1, limiter1 = self.build_app(key_prefix="moo", storage_uri=storage_uri)
        app2, limiter2 = self.build_app({C.KEY_PREFIX: "cow"}, storage_uri=storage_uri)
        app3, limiter3 = self.build_app(storage_uri=storage_uri)
        apps = ((app1, limiter1), (app2, limiter2), (app3, limiter3))

        # The view name is intentionally identical in every app; renaming it
        # would let the endpoint name, rather than the prefix, separate keys.
        for app, limiter in apps:
            @app.route("/test")
            @limiter.limit("1/day")
            def t1():
                return "app1 test"

        for app, _ in apps:
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/test").status_code)
                self.assertEqual(429, cli.get("/test").status_code)
1287