Completed
Push — master ( c61a7a...08e939 )
by Ali-Akber
28s
created

ErrorHandlingTests.t2()   A

Complexity

Conditions 1

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
c 0
b 0
f 0
dl 0
loc 4
rs 10
1
"""
2
3
"""
4
import json
5
import logging
6
import time
7
import unittest
8
9
import hiro
10
import mock
11
import redis
12
import datetime
13
from flask import Flask, Blueprint, request, current_app, make_response
14
from flask_restful import Resource, Api as RestfulApi
15
from flask.views import View, MethodView
16
from limits.errors import ConfigurationError
17
from limits.storage import MemcachedStorage
18
from limits.strategies import MovingWindowRateLimiter
19
20
from flask_limiter.extension import C, Limiter, HEADERS
21
from flask_limiter.util import get_remote_address, get_ipaddr
22
from tests import FlaskLimiterTestCase
23
24
25
class ConfigurationTests(FlaskLimiterTestCase):
    """Check how ``Limiter`` reacts to application configuration values."""

    def test_invalid_strategy(self):
        """An unknown strategy name in the config raises ConfigurationError."""
        flask_app = Flask(__name__)
        flask_app.config.setdefault(C.STRATEGY, "fubar")
        with self.assertRaises(ConfigurationError):
            Limiter(flask_app, key_func=get_remote_address)

    def test_invalid_storage_string(self):
        """A storage URL with an unsupported scheme raises ConfigurationError."""
        flask_app = Flask(__name__)
        flask_app.config.setdefault(C.STORAGE_URL, "fubar://localhost:1234")
        with self.assertRaises(ConfigurationError):
            Limiter(flask_app, key_func=get_remote_address)

    def test_constructor_arguments_over_config(self):
        """Explicit constructor arguments take precedence over app.config."""
        flask_app = Flask(__name__)
        flask_app.config.setdefault(C.STRATEGY, "fixed-window-elastic-expiry")

        # An explicit ``strategy`` must win over C.STRATEGY.
        strategy_limiter = Limiter(
            strategy='moving-window', key_func=get_remote_address
        )
        strategy_limiter.init_app(flask_app)
        flask_app.config.setdefault(C.STORAGE_URL, "redis://localhost:6379")
        self.assertEqual(
            type(strategy_limiter._limiter), MovingWindowRateLimiter
        )

        # An explicit ``storage_uri`` must win over the redis C.STORAGE_URL
        # that is now present in the config.
        storage_limiter = Limiter(
            storage_uri='memcached://localhost:11211',
            key_func=get_remote_address
        )
        storage_limiter.init_app(flask_app)
        self.assertEqual(type(storage_limiter._storage), MemcachedStorage)
55
56 View Code Duplication
0 ignored issues
show
Duplication introduced by
This code seems to be duplicated in your project.
Loading history...
57
class ErrorHandlingTests(FlaskLimiterTestCase):
    """Rate-limit error responses, custom messages and storage failures."""

    def test_error_message(self):
        """The 429 description names the limit that was exceeded.

        The default error page is checked first. A custom 429 handler that
        renders ``e.description`` as JSON is then installed and checked.
        """
        app, _limiter = self.build_app({C.GLOBAL_LIMITS: "1 per day"})

        @app.route("/")
        def null():
            return ""

        with app.test_client() as cli:
            cli.get("/")
            self.assertIn("1 per 1 day", cli.get("/").data.decode())

            # NOTE(review): this handler is registered after the app has
            # already served requests; newer Flask releases may reject
            # that. Confirm against the pinned Flask version.
            @app.errorhandler(429)
            def ratelimit_handler(e):
                return make_response(
                    '{"error" : "rate limit %s"}' % str(e.description), 429
                )

            self.assertEqual(
                {'error': 'rate limit 1 per 1 day'},
                json.loads(cli.get("/").data.decode())
            )

    def test_custom_error_message(self):
        """``error_message`` works as a string, as a callable, and on shared
        limits. It is surfaced as the 429 description."""
        app, limiter = self.build_app()

        @app.errorhandler(429)
        def ratelimit_handler(e):
            return make_response(e.description, 429)

        def l1():
            return "1/second"

        def e1():
            return "dos"

        @limiter.limit("1/second", error_message="uno")
        @app.route("/t1")
        def t1():
            return "1"

        @limiter.limit(l1, error_message=e1)
        @app.route("/t2")
        def t2():
            return "2"

        s1 = limiter.shared_limit(
            "1/second", scope='error_message', error_message="tres"
        )

        @app.route("/t3")
        @s1
        def t3():
            return "3"

        expected = (("/t1", b'uno'), ("/t2", b'dos'), ("/t3", b'tres'))
        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                for path, message in expected:
                    # The first hit consumes the 1/second allowance; the
                    # second must be rejected with the custom message.
                    cli.get(path)
                    resp = cli.get(path)
                    self.assertEqual(429, resp.status_code)
                    self.assertEqual(resp.data, message)

    def test_swallow_error(self):
        """With SWALLOW_ERRORS enabled, a failing storage hit does not
        block the request."""
        app, _limiter = self.build_app({
            C.GLOBAL_LIMITS: "1 per day",
            C.SWALLOW_ERRORS: True
        })

        @app.route("/")
        def null():
            return "ok"

        with app.test_client() as cli:
            with mock.patch(
                "limits.strategies.FixedWindowRateLimiter.hit",
                side_effect=Exception
            ):
                self.assertIn("ok", cli.get("/").data.decode())

    def test_no_swallow_error(self):
        """Without SWALLOW_ERRORS, a storage failure becomes a 500."""
        app, _limiter = self.build_app({
            C.GLOBAL_LIMITS: "1 per day",
        })

        @app.route("/")
        def null():
            return "ok"

        @app.errorhandler(500)
        def e500(e):
            return str(e), 500

        with app.test_client() as cli:
            with mock.patch(
                "limits.strategies.FixedWindowRateLimiter.hit",
                side_effect=Exception("underlying")
            ):
                self.assertEqual(500, cli.get("/").status_code)
                self.assertEqual("underlying", cli.get("/").data.decode())

    def test_fallback_to_memory_config(self):
        """The in-memory fallback can come from the constructor or the
        config."""
        _, limiter = self.build_app(
            config={C.ENABLED: True},
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
            in_memory_fallback=["1/minute"]
        )
        self.assertEqual(len(limiter._in_memory_fallback), 1)

        _, limiter = self.build_app(
            config={C.ENABLED: True,
                    C.IN_MEMORY_FALLBACK: "1/minute"},
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
        )
        self.assertEqual(len(limiter._in_memory_fallback), 1)

    def test_fallback_to_memory_backoff_check(self):
        """While redis is down the in-memory fallback limit applies.

        Once redis recovers, the regular limits resume only after the next
        storage check.
        """
        app, _limiter = self.build_app(
            config={C.ENABLED: True},
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
            in_memory_fallback=["1/minute"]
        )

        @app.route("/t1")
        def t1():
            return "test"

        with app.test_client() as cli:
            with hiro.Timeline() as timeline:
                # side_effect accepts any call signature (positional and
                # keyword), so every redis command fails the same way.
                with mock.patch(
                    "redis.client.Redis.execute_command",
                    side_effect=Exception("redis dead")
                ):
                    self.assertEqual(cli.get("/t1").status_code, 200)
                    self.assertEqual(cli.get("/t1").status_code, 429)
                    for seconds in (1, 2, 4, 8, 16):
                        timeline.forward(seconds)
                        self.assertEqual(cli.get("/t1").status_code, 429)
                    # Over a minute has now elapsed in total, so the
                    # 1/minute fallback window admits one more request.
                    timeline.forward(32)
                    self.assertEqual(cli.get("/t1").status_code, 200)
                # Redis works again, but the limiter has not re-checked
                # storage yet. This request still counts against the
                # exhausted in-memory fallback.
                self.assertEqual(cli.get("/t1").status_code, 429)
                timeline.forward(2)
                # After the re-check the 5/minute default limit applies.
                for _ in range(5):
                    self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)

    def test_fallback_to_memory(self):
        """Limits fall back to memory while redis fails, then resume on
        redis."""
        app, limiter = self.build_app(
            config={C.ENABLED: True},
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
            in_memory_fallback=["1/minute"]
        )

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.limit("3 per minute")
        def t2():
            return "test"

        with app.test_client() as cli:
            for _ in range(5):
                self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
            for _ in range(3):
                self.assertEqual(cli.get("/t2").status_code, 200)
            self.assertEqual(cli.get("/t2").status_code, 429)

            with mock.patch(
                "redis.client.Redis.execute_command",
                side_effect=Exception("redis dead")
            ):
                # Both routes are now governed by the 1/minute fallback.
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
                self.assertEqual(cli.get("/t2").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 429)
            # redis back to normal, go back to regular limits
            with hiro.Timeline() as timeline:
                timeline.forward(2)
                limiter._storage.storage.flushall()
                for _ in range(3):
                    self.assertEqual(cli.get("/t2").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 429)
282
283
284
class DecoratorTests(FlaskLimiterTestCase):
    """Per-route limits declared with the ``Limiter`` decorators."""

    @staticmethod
    def _logged_messages(handler):
        """Return the ``msg`` of each record passed to ``handler.handle``.

        ``handler`` is the ``mock.Mock`` attached to ``limiter.logger``.
        The first positional argument of every recorded call is the log
        record.
        """
        return [recorded[0][0].msg
                for recorded in handler.handle.call_args_list]

    def test_multiple_decorators(self):
        """Stacked limits with different key functions are all enforced."""
        app, limiter = self.build_app(key_func=get_ipaddr)

        @app.route("/t1")
        @limiter.limit(
            "100 per minute", lambda: "test"
        )  # effectively becomes a limit for all users
        @limiter.limit("50/minute")  # per ip as per default key_func
        def t1():
            return "test"

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                for i in range(100):
                    self.assertEqual(
                        200 if i < 50 else 429,
                        cli.get(
                            "/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"}
                        ).status_code
                    )
                for _ in range(50):
                    self.assertEqual(200, cli.get("/t1").status_code)
                self.assertEqual(429, cli.get("/t1").status_code)
                self.assertEqual(
                    429,
                    cli.get(
                        "/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}
                    ).status_code
                )

    def test_exempt_routes(self):
        """``limiter.exempt`` removes a route from the default limits."""
        app, limiter = self.build_app(default_limits=["1/minute"])

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.exempt
        def t2():
            return "test"

        with app.test_client() as cli:
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
            self.assertEqual(cli.get("/t2").status_code, 200)
            self.assertEqual(cli.get("/t2").status_code, 200)

    def test_decorated_dynamic_limits(self):
        """Callable limits are evaluated per request, in request/app
        context."""
        app, limiter = self.build_app(
            {"X": "2 per second"}, default_limits=["1/second"]
        )

        # Looked up, never mutated, by request_context_limit below.
        per_address_limits = {
            "127.0.0.1": "10 per minute",
            "127.0.0.2": "1 per minute",
        }

        def request_context_limit():
            remote_addr = (
                (request.access_route and request.access_route[0])
                or request.remote_addr
                or '127.0.0.1'
            )
            return per_address_limits.get(remote_addr, '1 per minute')

        @app.route("/t1")
        @limiter.limit("20/day")
        @limiter.limit(lambda: current_app.config.get("X"))
        @limiter.limit(request_context_limit)
        def t1():
            return "42"

        @app.route("/t2")
        @limiter.limit(lambda: current_app.config.get("X"))
        def t2():
            return "42"

        R1 = {"X_FORWARDED_FOR": "127.0.0.1, 127.0.0.0"}
        R2 = {"X_FORWARDED_FOR": "127.0.0.2"}

        with app.test_client() as cli:
            with hiro.Timeline().freeze() as timeline:
                for _ in range(10):
                    self.assertEqual(
                        cli.get("/t1", headers=R1).status_code, 200
                    )
                    timeline.forward(1)
                self.assertEqual(cli.get("/t1", headers=R1).status_code, 429)
                self.assertEqual(cli.get("/t1", headers=R2).status_code, 200)
                self.assertEqual(cli.get("/t1", headers=R2).status_code, 429)
                timeline.forward(60)
                self.assertEqual(cli.get("/t1", headers=R2).status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/t2").status_code, 200)

    def test_invalid_decorated_dynamic_limits(self):
        """An unparsable dynamic limit is logged and the default applies."""
        app = Flask(__name__)
        app.config.setdefault("X", "2 per sec")
        limiter = Limiter(
            app, default_limits=["1/second"], key_func=get_remote_address
        )
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)

        @app.route("/t1")
        @limiter.limit(lambda: current_app.config.get("X"))
        def t1():
            return "42"

        with app.test_client() as cli:
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        # 2 for invalid limit, 1 for warning.
        self.assertEqual(mock_handler.handle.call_count, 3)
        messages = self._logged_messages(mock_handler)
        self.assertIn("failed to load ratelimit", messages[0])
        self.assertIn("failed to load ratelimit", messages[1])
        self.assertIn("exceeded at endpoint", messages[2])

    def test_invalid_decorated_static_limits(self):
        """An unparsable static limit is logged and the default applies."""
        app = Flask(__name__)
        limiter = Limiter(
            app, default_limits=["1/second"], key_func=get_remote_address
        )
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)

        @app.route("/t1")
        @limiter.limit("2/sec")
        def t1():
            return "42"

        with app.test_client() as cli:
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        messages = self._logged_messages(mock_handler)
        self.assertIn("failed to configure", messages[0])
        self.assertIn("exceeded at endpoint", messages[1])

    def test_named_shared_limit(self):
        """Routes sharing a named scope share one counter."""
        app, limiter = self.build_app()
        shared_limit_a = limiter.shared_limit("1/minute", scope='a')
        shared_limit_b = limiter.shared_limit("1/minute", scope='b')

        @app.route("/t1")
        @shared_limit_a
        def route1():
            return "route1"

        @app.route("/t2")
        @shared_limit_a
        def route2():
            return "route2"

        @app.route("/t3")
        @shared_limit_b
        def route3():
            return "route3"

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/t1").status_code)
                self.assertEqual(200, cli.get("/t3").status_code)
                self.assertEqual(429, cli.get("/t2").status_code)

    def test_dynamic_shared_limit(self):
        """A callable scope is invoked with the endpoint name per request."""
        app, limiter = self.build_app()
        fn_a = mock.Mock()
        fn_b = mock.Mock()
        fn_a.return_value = "foo"
        fn_b.return_value = "bar"

        dy_limit_a = limiter.shared_limit("1/minute", scope=fn_a)
        dy_limit_b = limiter.shared_limit("1/minute", scope=fn_b)

        @app.route("/t1")
        @dy_limit_a
        def t1():
            return "route1"

        @app.route("/t2")
        @dy_limit_a
        def t2():
            return "route2"

        @app.route("/t3")
        @dy_limit_b
        def t3():
            return "route3"

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/t1").status_code)
                self.assertEqual(200, cli.get("/t3").status_code)
                self.assertEqual(429, cli.get("/t2").status_code)
                self.assertEqual(429, cli.get("/t3").status_code)
                self.assertEqual(2, fn_a.call_count)
                self.assertEqual(2, fn_b.call_count)
                fn_b.assert_called_with("t3")
                fn_a.assert_has_calls([mock.call("t1"), mock.call("t2")])

    def test_conditional_limits(self):
        """Test that the conditional activation of the limits work."""
        app = Flask(__name__)
        limiter = Limiter(app, key_func=get_remote_address)

        @app.route("/limited")
        @limiter.limit("1 per day")
        def limited_route():
            return "passed"

        @app.route("/unlimited")
        @limiter.limit("1 per day", exempt_when=lambda: True)
        def never_limited_route():
            return "should always pass"

        # Read through the closure below; rebinding it later in this
        # function toggles the exemption.
        is_exempt = False

        @app.route("/conditional")
        @limiter.limit("1 per day", exempt_when=lambda: is_exempt)
        def conditionally_limited_route():
            return "conditional"

        with app.test_client() as cli:
            self.assertEqual(cli.get("/limited").status_code, 200)
            self.assertEqual(cli.get("/limited").status_code, 429)

            self.assertEqual(cli.get("/unlimited").status_code, 200)
            self.assertEqual(cli.get("/unlimited").status_code, 200)

            self.assertEqual(cli.get("/conditional").status_code, 200)
            self.assertEqual(cli.get("/conditional").status_code, 429)
            is_exempt = True
            self.assertEqual(cli.get("/conditional").status_code, 200)
            is_exempt = False
            self.assertEqual(cli.get("/conditional").status_code, 429)

    def test_conditional_shared_limits(self):
        """Test that conditional shared limits work."""
        app = Flask(__name__)
        limiter = Limiter(app, key_func=get_remote_address)

        @app.route("/limited")
        @limiter.shared_limit("1 per day", "test_scope")
        def limited_route():
            return "passed"

        @app.route("/unlimited")
        @limiter.shared_limit(
            "1 per day", "test_scope", exempt_when=lambda: True
        )
        def never_limited_route():
            return "should always pass"

        # Read through the closure below; rebinding it later in this
        # function toggles the exemption.
        is_exempt = False

        @app.route("/conditional")
        @limiter.shared_limit(
            "1 per day", "test_scope", exempt_when=lambda: is_exempt
        )
        def conditionally_limited_route():
            return "conditional"

        with app.test_client() as cli:
            self.assertEqual(cli.get("/unlimited").status_code, 200)
            self.assertEqual(cli.get("/unlimited").status_code, 200)

            self.assertEqual(cli.get("/limited").status_code, 200)
            self.assertEqual(cli.get("/limited").status_code, 429)

            self.assertEqual(cli.get("/conditional").status_code, 429)
            is_exempt = True
            self.assertEqual(cli.get("/conditional").status_code, 200)
            is_exempt = False
            self.assertEqual(cli.get("/conditional").status_code, 429)

    def test_whitelisting(self):
        """Requests accepted by a ``request_filter`` bypass all limits."""
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["1/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/")
        def t():
            return "test"

        @limiter.request_filter
        def w():
            return request.headers.get("internal", None) == "true"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                self.assertEqual(cli.get("/").status_code, 200)
                self.assertEqual(cli.get("/").status_code, 429)
                timeline.forward(60)
                self.assertEqual(cli.get("/").status_code, 200)

                for _ in range(10):
                    self.assertEqual(
                        cli.get("/", headers={"internal": "true"}).status_code,
                        200
                    )

    def test_separate_method_limits(self):
        """``per_method=True`` keeps separate counters per HTTP method."""
        app, limiter = self.build_app()

        @limiter.limit("1/second", per_method=True)
        @app.route("/", methods=["GET", "POST"])
        def root():
            return "root"

        with hiro.Timeline():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/").status_code)
                self.assertEqual(429, cli.get("/").status_code)
                self.assertEqual(200, cli.post("/").status_code)
                self.assertEqual(429, cli.post("/").status_code)

    def test_explicit_method_limits(self):
        """``methods=[...]`` restricts a limit to the listed HTTP methods."""
        app, limiter = self.build_app()

        @limiter.limit("1/second", methods=["GET"])
        @app.route("/", methods=["GET", "POST"])
        def root():
            return "root"

        with hiro.Timeline():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/").status_code)
                self.assertEqual(429, cli.get("/").status_code)
                self.assertEqual(200, cli.post("/").status_code)
                self.assertEqual(200, cli.post("/").status_code)
645
646
647
class BlueprintTests(FlaskLimiterTestCase):
    """Limits declared inside, or applied to, Flask blueprints."""

    @staticmethod
    def _logged_messages(handler):
        """Return the ``msg`` of each record passed to ``handler.handle``.

        ``handler`` is the ``mock.Mock`` attached to ``limiter.logger``.
        The first positional argument of every recorded call is the log
        record.
        """
        return [recorded[0][0].msg
                for recorded in handler.handle.call_args_list]

    def test_blueprint(self):
        """Decorated blueprint views override the default limits."""
        app, limiter = self.build_app(default_limits=["1/minute"])
        bp = Blueprint("main", __name__)

        @bp.route("/t1")
        def t1():
            return "test"

        @bp.route("/t2")
        @limiter.limit("10 per minute")
        def t2():
            return "test"

        app.register_blueprint(bp)

        with app.test_client() as cli:
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
            for _ in range(10):
                self.assertEqual(cli.get("/t2").status_code, 200)
            self.assertEqual(cli.get("/t2").status_code, 429)

    def test_register_blueprint(self):
        """Static, exempt and dynamic limits can target whole blueprints."""
        app, limiter = self.build_app(default_limits=["1/minute"])
        bp_1 = Blueprint("bp1", __name__)
        bp_2 = Blueprint("bp2", __name__)
        bp_3 = Blueprint("bp3", __name__)
        bp_4 = Blueprint("bp4", __name__)

        @bp_1.route("/t1")
        def t1():
            return "test"

        @bp_1.route("/t2")
        def t2():
            return "test"

        @bp_2.route("/t3")
        def t3():
            return "test"

        @bp_3.route("/t4")
        def t4():
            return "test"

        # Previously also named ``t4`` (shadowing the view above); renamed
        # to match its route.
        @bp_4.route("/t5")
        def t5():
            return "test"

        def dy_limit():
            return "1/second"

        app.register_blueprint(bp_1)
        app.register_blueprint(bp_2)
        app.register_blueprint(bp_3)
        app.register_blueprint(bp_4)

        limiter.limit("1/second")(bp_1)
        limiter.exempt(bp_3)
        limiter.limit(dy_limit)(bp_4)

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                # bp_1: explicit 1/second limit
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/t2").status_code, 200)

                # bp_2: only the 1/minute default applies
                self.assertEqual(cli.get("/t3").status_code, 200)
                for _ in range(10):
                    timeline.forward(1)
                    self.assertEqual(cli.get("/t3").status_code, 429)

                # bp_3: exempt
                for _ in range(10):
                    self.assertEqual(cli.get("/t4").status_code, 200)

                # bp_4: dynamic 1/second limit
                self.assertEqual(cli.get("/t5").status_code, 200)
                self.assertEqual(cli.get("/t5").status_code, 429)

    def test_invalid_decorated_static_limit_blueprint(self):
        """An unparsable static blueprint limit is logged and the default
        applies."""
        app = Flask(__name__)
        limiter = Limiter(
            app, default_limits=["1/second"], key_func=get_remote_address
        )
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)
        bp = Blueprint("bp1", __name__)

        @bp.route("/t1")
        def t1():
            return "42"

        limiter.limit("2/sec")(bp)
        app.register_blueprint(bp)

        with app.test_client() as cli:
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        messages = self._logged_messages(mock_handler)
        self.assertIn("failed to configure", messages[0])
        self.assertIn("exceeded at endpoint", messages[1])

    def test_invalid_decorated_dynamic_limits_blueprint(self):
        """An unparsable dynamic blueprint limit is logged and the default
        applies."""
        app = Flask(__name__)
        app.config.setdefault("X", "2 per sec")
        limiter = Limiter(
            app, default_limits=["1/second"], key_func=get_remote_address
        )
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)
        bp = Blueprint("bp1", __name__)

        @bp.route("/t1")
        def t1():
            return "42"

        limiter.limit(lambda: current_app.config.get("X"))(bp)
        app.register_blueprint(bp)

        with app.test_client() as cli:
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        # 2 for invalid limit, 1 for warning.
        self.assertEqual(mock_handler.handle.call_count, 3)
        messages = self._logged_messages(mock_handler)
        self.assertIn("failed to load ratelimit", messages[0])
        self.assertIn("failed to load ratelimit", messages[1])
        self.assertIn("exceeded at endpoint", messages[2])
796
797
class ViewsTests(FlaskLimiterTestCase):
    """Rate limits attached to class-based views through ``decorators``."""

    def _replay(self, cli, timeline, steps):
        """Drive *cli* through *steps*, asserting every response status.

        Each step is either an ``int`` -- seconds to advance the frozen
        *timeline* by -- or an ``(expected_status, verb, path)`` tuple where
        *verb* names a test-client method such as ``"get"`` or ``"post"``.
        """
        for step in steps:
            if isinstance(step, int):
                timeline.forward(step)
                continue
            expected, verb, path = step
            send = getattr(cli, verb)
            self.assertEqual(expected, send(path).status_code)

    def test_pluggable_views(self):
        app, limiter = self.build_app(default_limits=["1/hour"])

        class Va(View):
            methods = ['GET', 'POST']
            decorators = [limiter.limit("2/second")]

            def dispatch_request(self):
                return request.method.lower()

        class Vb(View):
            methods = ['GET']
            decorators = [limiter.limit("1/second, 3/minute")]

            def dispatch_request(self):
                return request.method.lower()

        class Vc(View):
            methods = ['GET']

            def dispatch_request(self):
                return request.method.lower()

        for rule, endpoint, view in (
            ("/a", "a", Va), ("/b", "b", Vb), ("/c", "c", Vc)
        ):
            app.add_url_rule(rule, view_func=view.as_view(endpoint))

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                self._replay(cli, timeline, [
                    (200, "get", "/a"),
                    (200, "get", "/a"),
                    (429, "post", "/a"),
                    (200, "get", "/b"),
                    1,
                    (200, "get", "/b"),
                    1,
                    (200, "get", "/b"),
                    1,
                    (429, "get", "/b"),
                    (200, "get", "/c"),
                    (429, "get", "/c"),
                ])

    def test_pluggable_method_views(self):
        app, limiter = self.build_app(default_limits=["1/hour"])

        class Va(MethodView):
            decorators = [limiter.limit("2/second")]

            def get(self):
                return request.method.lower()

            def post(self):
                return request.method.lower()

        class Vb(MethodView):
            decorators = [limiter.limit("1/second, 3/minute")]

            def get(self):
                return request.method.lower()

        class Vc(MethodView):
            def get(self):
                return request.method.lower()

        class Vd(MethodView):
            decorators = [limiter.limit("1/minute", methods=['get'])]

            def get(self):
                return request.method.lower()

            def post(self):
                return request.method.lower()

        for rule, endpoint, view in (
            ("/a", "a", Va), ("/b", "b", Vb),
            ("/c", "c", Vc), ("/d", "d", Vd),
        ):
            app.add_url_rule(rule, view_func=view.as_view(endpoint))

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                self._replay(cli, timeline, [
                    (200, "get", "/a"),
                    (200, "get", "/a"),
                    (429, "get", "/a"),
                    (429, "post", "/a"),
                    (200, "get", "/b"),
                    1,
                    (200, "get", "/b"),
                    1,
                    (200, "get", "/b"),
                    1,
                    (429, "get", "/b"),
                    (200, "get", "/c"),
                    (429, "get", "/c"),
                    # only GET is limited on /d; POST stays unrestricted
                    (200, "get", "/d"),
                    (429, "get", "/d"),
                    (200, "post", "/d"),
                    (200, "post", "/d"),
                ])

    def test_flask_restful_resource(self):
        app, limiter = self.build_app(default_limits=["1/hour"])
        api = RestfulApi(app)

        class Va(Resource):
            decorators = [limiter.limit("2/second")]

            def get(self):
                return request.method.lower()

            def post(self):
                return request.method.lower()

        class Vb(Resource):
            decorators = [limiter.limit("1/second, 3/minute")]

            def get(self):
                return request.method.lower()

        class Vc(Resource):
            def get(self):
                return request.method.lower()

        for resource, rule in ((Va, "/a"), (Vb, "/b"), (Vc, "/c")):
            api.add_resource(resource, rule)

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                self._replay(cli, timeline, [
                    (200, "get", "/a"),
                    (200, "get", "/a"),
                    (429, "get", "/a"),
                    (429, "post", "/a"),
                    (200, "get", "/b"),
                    1,
                    (200, "get", "/b"),
                    1,
                    (200, "get", "/b"),
                    1,
                    (429, "get", "/b"),
                    (200, "get", "/c"),
                    (429, "get", "/c"),
                ])
937
938
939
class FlaskExtTests(FlaskLimiterTestCase):
940
    def test_reset(self):
        """``Limiter.reset()`` discards the usage recorded so far.

        Under a ``1 per day`` global limit the second request is rejected;
        after ``reset()`` the route serves one more request before the
        limit message appears again.
        """
        app, limiter = self.build_app({C.GLOBAL_LIMITS: "1 per day"})

        @app.route("/")
        def null():
            return "Hello Reset"

        with app.test_client() as cli:
            # The first request was previously unchecked; pin it explicitly.
            self.assertEqual("Hello Reset", cli.get("/").data.decode())
            self.assertIn("1 per 1 day", cli.get("/").data.decode())
            limiter.reset()
            self.assertEqual("Hello Reset", cli.get("/").data.decode())
            self.assertIn("1 per 1 day", cli.get("/").data.decode())
953
954 View Code Duplication
    def test_combined_rate_limits(self):
        """A decorated route and an undecorated route are counted separately.

        ``/t2`` has no decorator and is governed by the global
        ``1 per hour; 10 per day`` limits, so its second request is rejected
        even though ``/t1`` was hit first.
        """
        app, limiter = self.build_app(
            {C.GLOBAL_LIMITS: "1 per hour; 10 per day"}
        )

        @app.route("/t1")
        @limiter.limit("100 per hour;10/minute")
        def t1():
            return "t1"

        @app.route("/t2")
        def t2():
            return "t2"

        expectations = (("/t1", 200), ("/t2", 200), ("/t2", 429))
        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                for path, status in expectations:
                    self.assertEqual(status, cli.get(path).status_code)
973
974
    def test_key_func(self):
        """A per-route key function decides which budget a request draws on.

        The route's key function always returns ``"test"``, so 100 requests
        carrying a forwarded-for header and one request without it all share
        one ``100 per minute`` budget; the 101st request is rejected.
        """
        app, limiter = self.build_app()

        @app.route("/t1")
        @limiter.limit("100 per minute", lambda: "test")
        def t1():
            return "test"

        forwarded = {"X_FORWARDED_FOR": "127.0.0.2"}
        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                for _ in range(100):
                    response = cli.get("/t1", headers=forwarded)
                    self.assertEqual(200, response.status_code)
                self.assertEqual(429, cli.get("/t1").status_code)
994
995
    def test_logging(self):
        """An accepted request followed by a rejected one logs exactly once."""
        app = Flask(__name__)
        limiter = Limiter(app, key_func=get_remote_address)
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)
        # Detach the mock afterwards so it does not remain attached to the
        # limiter's logger for later tests.
        self.addCleanup(limiter.logger.removeHandler, mock_handler)

        @app.route("/t1")
        @limiter.limit("1/minute")
        def t1():
            return "test"

        with app.test_client() as cli:
            self.assertEqual(200, cli.get("/t1").status_code)
            self.assertEqual(429, cli.get("/t1").status_code)
        self.assertEqual(mock_handler.handle.call_count, 1)
1011
1012
    def test_reuse_logging(self):
        """Handlers copied from ``app.logger`` receive the limiter's records.

        After two requests against a ``1/minute`` route, the handler that was
        attached to the app logger and shared with the limiter logger has
        handled exactly one record.
        """
        app = Flask(__name__)
        app_handler = mock.Mock()
        app_handler.level = logging.INFO
        app.logger.addHandler(app_handler)
        # Loggers are looked up by name and may outlive this test, so undo
        # every handler attachment made here once the test finishes.
        self.addCleanup(app.logger.removeHandler, app_handler)
        limiter = Limiter(app, key_func=get_remote_address)
        for handler in app.logger.handlers:
            limiter.logger.addHandler(handler)
            self.addCleanup(limiter.logger.removeHandler, handler)

        @app.route("/t1")
        @limiter.limit("1/minute")
        def t1():
            return "42"

        with app.test_client() as cli:
            cli.get("/t1")
            cli.get("/t1")

        self.assertEqual(app_handler.handle.call_count, 1)
1031
1032 View Code Duplication
    def test_disabled_flag(self):
        """With ``C.ENABLED`` set to False no limit is enforced at all.

        ``/t1`` is hit twice against a ``1/minute`` default and ``/t2`` eleven
        times against ``10 per minute``; every response is still 200.
        """
        app, limiter = self.build_app(
            config={C.ENABLED: False}, default_limits=["1/minute"]
        )

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.limit("10 per minute")
        def t2():
            return "test"

        with app.test_client() as cli:
            for path, hits in (("/t1", 2), ("/t2", 11)):
                for _ in range(hits):
                    self.assertEqual(cli.get(path).status_code, 200)
1052
1053
1054
    def test_multiple_apps(self):
        """One ``Limiter`` initialised on two apps enforces each app's routes.

        Both apps share the ``1/second`` default; each app's decorated routes
        carry their own limits, checked against a single frozen timeline.
        """
        app1 = Flask(__name__)
        app2 = Flask(__name__)

        limiter = Limiter(
            default_limits=["1/second"], key_func=get_remote_address
        )
        for app in (app1, app2):
            limiter.init_app(app)

        @app1.route("/ping")
        def ping():
            return "PONG"

        @app1.route("/slowping")
        @limiter.limit("1/minute")
        def slow_ping():
            return "PONG"

        @app2.route("/ping")
        @limiter.limit("2/second")
        def ping_2():
            return "PONG"

        @app2.route("/slowping")
        @limiter.limit("2/minute")
        def slow_ping_2():
            return "PONG"

        def replay(client, timeline, steps):
            # an int advances the frozen clock; a (path, status) pair is a GET
            for step in steps:
                if isinstance(step, int):
                    timeline.forward(step)
                else:
                    path, status = step
                    self.assertEqual(client.get(path).status_code, status)

        with hiro.Timeline().freeze() as timeline:
            with app1.test_client() as cli:
                replay(cli, timeline, [
                    ("/ping", 200), ("/ping", 429), 1,
                    ("/ping", 200), ("/slowping", 200), 59,
                    ("/slowping", 429), 1,
                    ("/slowping", 200),
                ])
            with app2.test_client() as cli:
                replay(cli, timeline, [
                    ("/ping", 200), ("/ping", 200), ("/ping", 429), 1,
                    ("/ping", 200), ("/slowping", 200), 59,
                    ("/slowping", 200), ("/slowping", 429), 1,
                    ("/slowping", 200),
                ])
1106
1107
    def test_headers_no_breach(self):
        """Rate-limit headers are present on responses that are not limited.

        ``/t1`` reports the ``10/minute`` default; ``/t2`` reports its
        ``2/second`` limit.
        """
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.limit("2/second; 5 per minute; 10/hour")
        def t2():
            return "test"

        def check(response, limit, remaining, reset_in, retry_after):
            headers = response.headers
            self.assertEqual(headers.get('X-RateLimit-Limit'), limit)
            self.assertEqual(headers.get('X-RateLimit-Remaining'), remaining)
            self.assertEqual(
                headers.get('X-RateLimit-Reset'),
                str(int(time.time() + reset_in))
            )
            self.assertEqual(headers.get('Retry-After'), retry_after)

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                check(cli.get("/t1"), '10', '9', 61, '60')
                check(cli.get("/t2"), '2', '1', 2, '1')
1148
1149
    def test_headers_breach(self):
        """Headers after a breach report the limit that was exhausted.

        Eleven requests one second apart exhaust ``10 per minute``; the final
        response reports that limit with nothing remaining.
        """
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                for _ in range(11):
                    last_response = cli.get("/t1")
                    timeline.forward(1)

                headers = last_response.headers
                self.assertEqual(headers.get('X-RateLimit-Limit'), '10')
                self.assertEqual(headers.get('X-RateLimit-Remaining'), '0')
                self.assertEqual(
                    headers.get('X-RateLimit-Reset'),
                    str(int(time.time() + 50))
                )
                self.assertEqual(headers.get('Retry-After'), '50')
1178
1179
    def test_retry_after(self):
        """Waiting the advertised ``Retry-After`` seconds re-admits the client."""
        app = Flask(__name__)
        _ = Limiter(
            app,
            default_limits=["1/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/t1")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                wait = int(cli.get("/t1").headers.get('Retry-After'))
                self.assertGreater(wait, 0)
                timeline.forward(wait)
                self.assertEqual(cli.get("/t1").status_code, 200)
1200
1201
    def test_custom_headers_from_setter(self):
        """Header names can be overridden through ``_header_mapping``.

        ``retry_after='http-date'`` makes ``Retry-After`` an HTTP date rather
        than a number of seconds.
        """
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address,
            retry_after='http-date'
        )
        renamed = (
            (HEADERS.RESET, 'X-Reset'),
            (HEADERS.LIMIT, 'X-Limit'),
            (HEADERS.REMAINING, 'X-Remaining'),
        )
        for header, name in renamed:
            limiter._header_mapping[header] = name

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        # Freezing at epoch 0 makes the HTTP-date header predictable.
        with hiro.Timeline().freeze(0) as timeline:
            with app.test_client() as cli:
                for _ in range(11):
                    last_response = cli.get("/t1")
                    timeline.forward(1)

                headers = last_response.headers
                self.assertEqual(headers.get('X-Limit'), '10')
                self.assertEqual(headers.get('X-Remaining'), '0')
                self.assertEqual(
                    headers.get('X-Reset'), str(int(time.time() + 50))
                )
                self.assertEqual(
                    headers.get('Retry-After'),
                    'Thu, 01 Jan 1970 00:01:01 GMT'
                )
1234
1235
    def test_custom_headers_from_config(self):
        """Header names can be overridden through app configuration."""
        app = Flask(__name__)
        overrides = (
            (C.HEADER_LIMIT, "X-Limit"),
            (C.HEADER_REMAINING, "X-Remaining"),
            (C.HEADER_RESET, "X-Reset"),
        )
        for key, name in overrides:
            app.config.setdefault(key, name)
        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                for _ in range(11):
                    last_response = cli.get("/t1")
                    timeline.forward(1)

                headers = last_response.headers
                self.assertEqual(headers.get('X-Limit'), '10')
                self.assertEqual(headers.get('X-Remaining'), '0')
                self.assertEqual(
                    headers.get('X-Reset'), str(int(time.time() + 50))
                )
1263
1264 View Code Duplication
    def test_application_shared_limit(self):
        """Application limits form one budget shared by every route."""
        app, _limiter = self.build_app(application_limits=["2/minute"])

        @app.route("/t1")
        def t1():
            return "route1"

        @app.route("/t2")
        def t2():
            return "route2"

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                for path, status in (("/t1", 200), ("/t2", 200), ("/t1", 429)):
                    self.assertEqual(status, cli.get(path).status_code)
1280
1281 View Code Duplication
    def test_callable_default_limit(self):
        """Default limits may be callables; each route gets its own budget."""
        app, _limiter = self.build_app(default_limits=[lambda: "1/minute"])

        @app.route("/t1")
        def t1():
            return "t1"

        @app.route("/t2")
        def t2():
            return "t2"

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                for expected in (200, 429):
                    for path in ("/t1", "/t2"):
                        self.assertEqual(cli.get(path).status_code, expected)
1298
1299 View Code Duplication
    def test_callable_application_limit(self):
        """Application limits may be callables; the budget spans all routes."""
        app, _limiter = self.build_app(
            application_limits=[lambda: "1/minute"]
        )

        @app.route("/t1")
        def t1():
            return "t1"

        @app.route("/t2")
        def t2():
            return "t2"

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                for path, status in (("/t1", 200), ("/t2", 429)):
                    self.assertEqual(cli.get(path).status_code, status)
1315
1316
    def test_no_auto_check(self):
        """With ``auto_check=False`` limits apply only when ``check()`` runs."""
        app, limiter = self.build_app(auto_check=False)

        @limiter.limit("1/second", per_method=True)
        @app.route("/", methods=["GET", "POST"])
        def root():
            return "root"

        def two_gets(second_status):
            with hiro.Timeline().freeze():
                with app.test_client() as cli:
                    self.assertEqual(200, cli.get("/").status_code)
                    self.assertEqual(second_status, cli.get("/").status_code)

        # Nothing calls check() yet, so the 1/second limit is not enforced.
        two_gets(200)

        # attach before_request to perform check
        @app.before_request
        def _():
            limiter.check()

        two_gets(429)
1338
1339
    def test_custom_key_prefix(self):
        """Limiters on one Redis instance with different key prefixes.

        The prefix comes from the constructor (``"moo"``), from config
        (``"cow"``), or is left at its default. Each app gets its own
        ``1/day`` budget.
        """
        redis_uri = "redis://localhost:6379"
        app1, limiter1 = self.build_app(key_prefix="moo", storage_uri=redis_uri)
        app2, limiter2 = self.build_app(
            {C.KEY_PREFIX: "cow"}, storage_uri=redis_uri
        )
        app3, limiter3 = self.build_app(storage_uri=redis_uri)
        pairs = ((app1, limiter1), (app2, limiter2), (app3, limiter3))

        # NOTE(review): every view is deliberately named ``t1`` so all three
        # apps share the same endpoint name -- presumably so that only the key
        # prefix separates their counters. Do not give them distinct names.
        for app, limiter in pairs:
            @app.route("/test")
            @limiter.limit("1/day")
            def t1():
                return "app1 test"

        for app, _ in pairs:
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/test").status_code)
                self.assertEqual(429, cli.get("/test").status_code)
1379