ErrorHandlingTests.null()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 3
Ratio 100 %

Importance

Changes 0
Metric Value
cc 1
c 0
b 0
f 0
dl 3
loc 3
rs 10
1
"""
2
3
"""
4
import json
5
import logging
6
import time
7
import unittest
8
from functools import wraps
9
10
import functools
11
import hiro
12
import mock
13
import redis
14
import datetime
15
from flask import Flask, Blueprint, request, current_app, make_response, g
16
from flask_restful import Resource, Api as RestfulApi
17
from flask.views import View, MethodView
18
from limits.errors import ConfigurationError
19
from limits.storage import MemcachedStorage
20
from limits.strategies import MovingWindowRateLimiter
21
22
from flask_limiter.extension import C, Limiter, HEADERS
23
from flask_limiter.util import get_remote_address, get_ipaddr
24
from tests import FlaskLimiterTestCase
25
26
27
class ConfigurationTests(FlaskLimiterTestCase):
    """Checks on how ``Limiter`` reacts to config values and arguments."""

    def test_invalid_strategy(self):
        """An unknown strategy name in the config raises ConfigurationError."""
        flask_app = Flask(__name__)
        flask_app.config.setdefault(C.STRATEGY, "fubar")
        with self.assertRaises(ConfigurationError):
            Limiter(flask_app, key_func=get_remote_address)

    def test_invalid_storage_string(self):
        """An unknown storage URL scheme raises ConfigurationError."""
        flask_app = Flask(__name__)
        flask_app.config.setdefault(C.STORAGE_URL, "fubar://localhost:1234")
        with self.assertRaises(ConfigurationError):
            Limiter(flask_app, key_func=get_remote_address)

    def test_constructor_arguments_over_config(self):
        """Constructor arguments are expected to win over app config.

        The config names the "fixed-window-elastic-expiry" strategy and a
        redis storage URL, while the limiters are built with
        ``strategy='moving-window'`` and a memcached ``storage_uri``; the
        resulting strategy and storage types must match the arguments.
        """
        flask_app = Flask(__name__)
        flask_app.config.setdefault(
            C.STRATEGY, "fixed-window-elastic-expiry"
        )
        by_strategy = Limiter(
            strategy='moving-window', key_func=get_remote_address
        )
        by_strategy.init_app(flask_app)
        # The storage URL is only added to the config after the first
        # limiter has been initialised.
        flask_app.config.setdefault(C.STORAGE_URL, "redis://localhost:6379")
        self.assertEqual(type(by_strategy._limiter), MovingWindowRateLimiter)

        by_storage = Limiter(
            storage_uri='memcached://localhost:11211',
            key_func=get_remote_address,
        )
        by_storage.init_app(flask_app)
        self.assertEqual(type(by_storage._storage), MemcachedStorage)
57
58
59
class ErrorHandlingTests(FlaskLimiterTestCase):
60 View Code Duplication
    def test_error_message(self):
        """Check the 429 body before and after a custom handler is added.

        With a global limit of "1 per day" the second request's body must
        contain "1 per 1 day"; once a 429 error handler is registered, the
        body must be the JSON document that handler builds.
        """
        app, _ = self.build_app({C.GLOBAL_LIMITS: "1 per day"})

        @app.route("/")
        def null():
            return ""

        with app.test_client() as client:
            # The first request uses up the single allowed hit.
            client.get("/")
            limited_body = client.get("/").data.decode()
            self.assertIn("1 per 1 day", limited_body)

            @app.errorhandler(429)
            def ratelimit_handler(e):
                body = '{"error" : "rate limit %s"}' % str(e.description)
                return make_response(body, 429)

            payload = json.loads(client.get("/").data.decode())
            self.assertEqual({'error': 'rate limit 1 per 1 day'}, payload)
80
81
    def test_custom_error_message(self):
        """Check that per-limit ``error_message`` values reach the 429 body.

        Covers a plain string on ``limit``, a callable on ``limit`` and a
        plain string on ``shared_limit``. The 429 handler echoes
        ``e.description`` so the body can be compared directly.
        """
        app, limiter = self.build_app()

        @app.errorhandler(429)
        def ratelimit_handler(e):
            return make_response(e.description, 429)

        @limiter.limit("1/second", error_message="uno")
        @app.route("/t1")
        def t1():
            return "1"

        # Both the limit and its message are supplied as callables here.
        @limiter.limit(lambda: "1/second", error_message=lambda: "dos")
        @app.route("/t2")
        def t2():
            return "2"

        scoped = limiter.shared_limit(
            "1/second", scope='error_message', error_message="tres"
        )

        @app.route("/t3")
        @scoped
        def t3():
            return "3"

        expectations = (("/t1", b'uno'), ("/t2", b'dos'), ("/t3", b'tres'))
        with hiro.Timeline().freeze(), app.test_client() as client:
            for path, body in expectations:
                client.get(path)  # uses up the single hit for this path
                response = client.get(path)
                self.assertEqual(429, response.status_code)
                self.assertEqual(response.data, body)
124
125
    def test_swallow_error(self):
        """With SWALLOW_ERRORS on, a failing ``hit`` must not break the view.

        ``FixedWindowRateLimiter.hit`` is patched to raise and the request
        is still expected to return the view's "ok" body.
        """
        app, _ = self.build_app({
            C.GLOBAL_LIMITS: "1 per day",
            C.SWALLOW_ERRORS: True,
        })

        @app.route("/")
        def null():
            return "ok"

        def explode(*args, **kwargs):
            raise Exception

        patched_hit = mock.patch(
            "limits.strategies.FixedWindowRateLimiter.hit",
            side_effect=explode,
        )
        with app.test_client() as client, patched_hit:
            self.assertIn("ok", client.get("/").data.decode())
145
146 View Code Duplication
    def test_no_swallow_error(self):
        """Without SWALLOW_ERRORS, a failing ``hit`` must surface as a 500.

        ``FixedWindowRateLimiter.hit`` is patched to raise; the 500 handler
        returns ``str(e)``, so the body is expected to be the exception text.
        """
        app, _ = self.build_app({C.GLOBAL_LIMITS: "1 per day"})

        @app.route("/")
        def null():
            return "ok"

        @app.errorhandler(500)
        def e500(e):
            return str(e), 500

        def explode(*args, **kwargs):
            raise Exception("underlying")

        patched_hit = mock.patch(
            "limits.strategies.FixedWindowRateLimiter.hit",
            side_effect=explode,
        )
        with app.test_client() as client, patched_hit:
            # Two separate requests are made, one per assertion.
            self.assertEqual(500, client.get("/").status_code)
            self.assertEqual("underlying", client.get("/").data.decode())
170
171
    def test_fallback_to_memory_config(self):
        """The in-memory fallback can come from an argument or the config.

        Either way of supplying a single "1/minute" fallback is expected to
        leave exactly one entry in ``limiter._in_memory_fallback``.
        """
        # Fallback supplied as a ``build_app`` keyword argument.
        _, from_kwarg = self.build_app(
            config={C.ENABLED: True},
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
            in_memory_fallback=["1/minute"],
        )
        self.assertEqual(len(from_kwarg._in_memory_fallback), 1)

        # Fallback supplied through the configuration mapping instead.
        app_config = {C.ENABLED: True, C.IN_MEMORY_FALLBACK: "1/minute"}
        _, from_config = self.build_app(
            config=app_config,
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
        )
        self.assertEqual(len(from_config._in_memory_fallback), 1)
187
188
    def test_fallback_to_memory_backoff_check(self):
        """Walk through the expected statuses while redis commands fail.

        ``Redis.execute_command`` is patched to raise. One request is then
        expected to pass and the following ones to be rejected while the
        clock is moved forward in doubling steps (1 to 16), with a 200 again
        after the 32 step. After the patch is lifted one more 429 is
        expected; after a further 2 forward, five 200s and then a 429.
        """
        app, _ = self.build_app(
            config={C.ENABLED: True},
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
            in_memory_fallback=["1/minute"],
        )

        @app.route("/t1")
        def t1():
            return "test"

        def redis_down(*args):
            raise Exception("redis dead")

        with app.test_client() as client, hiro.Timeline() as timeline:

            def status():
                return client.get("/t1").status_code

            with mock.patch(
                "redis.client.Redis.execute_command", side_effect=redis_down
            ):
                self.assertEqual(status(), 200)
                self.assertEqual(status(), 429)
                for step in (1, 2, 4, 8, 16):
                    timeline.forward(step)
                    self.assertEqual(status(), 429)
                timeline.forward(32)
                self.assertEqual(status(), 200)
            # redis back to normal, but exponential backoff will only
            # result in it being marked after pow(2,0) seconds and next
            # check
            self.assertEqual(status(), 429)
            timeline.forward(2)
            for _ in range(5):
                self.assertEqual(status(), 200)
            self.assertEqual(status(), 429)
235
236
    def test_fallback_to_memory(self):
        """Compare the limits seen with and without failing redis commands.

        Normally ``/t1`` allows five requests and ``/t2`` three. While
        ``Redis.execute_command`` is patched to raise, each route is expected
        to allow a single request. After the patch is lifted, the clock is
        moved forward and the storage flushed, ``/t2`` is expected to allow
        three requests again.
        """
        app, limiter = self.build_app(
            config={C.ENABLED: True},
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
            in_memory_fallback=["1/minute"],
        )

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.limit("3 per minute")
        def t2():
            return "test"

        def redis_down(*args):
            raise Exception("redis dead")

        with app.test_client() as client:

            def status(path):
                return client.get(path).status_code

            for _ in range(5):
                self.assertEqual(status("/t1"), 200)
            self.assertEqual(status("/t1"), 429)
            for _ in range(3):
                self.assertEqual(status("/t2"), 200)
            self.assertEqual(status("/t2"), 429)

            with mock.patch(
                "redis.client.Redis.execute_command", side_effect=redis_down
            ):
                for path in ("/t1", "/t2"):
                    self.assertEqual(status(path), 200)
                    self.assertEqual(status(path), 429)
            # redis back to normal, go back to regular limits
            with hiro.Timeline() as timeline:
                timeline.forward(2)
                limiter._storage.storage.flushall()
                for _ in range(3):
                    self.assertEqual(status("/t2"), 200)
                self.assertEqual(status("/t2"), 429)
284
285
286
class DecoratorTests(FlaskLimiterTestCase):
287
    def test_multiple_decorators(self):
        """Stack a constant-key limit on top of a per-address limit.

        One forwarded address is expected to get 50 successes out of 100
        requests, the client without the header then gets 50 more, and
        after that both it and a third forwarded address are rejected.
        """
        app, limiter = self.build_app(key_func=get_ipaddr)

        @app.route("/t1")
        @limiter.limit(
            "100 per minute", lambda: "test"
        )  # effectively becomes a limit for all users
        @limiter.limit("50/minute")  # per ip as per default key_func
        def t1():
            return "test"

        with hiro.Timeline().freeze(), app.test_client() as client:
            for attempt in range(100):
                expected = 200 if attempt < 50 else 429
                response = client.get(
                    "/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"}
                )
                self.assertEqual(expected, response.status_code)
            for _ in range(50):
                self.assertEqual(200, client.get("/t1").status_code)
            self.assertEqual(429, client.get("/t1").status_code)
            response = client.get(
                "/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}
            )
            self.assertEqual(429, response.status_code)
318
319 View Code Duplication
    def test_exempt_routes(self):
        """A route marked ``@limiter.exempt`` must skip the default limit.

        ``/t1`` is expected to be rejected on its second request, while the
        exempt ``/t2`` keeps returning 200.
        """
        app, limiter = self.build_app(default_limits=["1/minute"])

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.exempt
        def t2():
            return "test"

        with app.test_client() as client:
            outcomes = (
                ("/t1", 200),
                ("/t1", 429),
                ("/t2", 200),
                ("/t2", 200),
            )
            for path, expected in outcomes:
                self.assertEqual(client.get(path).status_code, expected)
336
337
    def test_decorated_dynamic_limits(self):
        """Exercise limits whose value is computed per request.

        ``/t1`` stacks a static limit, a limit read from the app config key
        "X" and a limit chosen from the caller's address; ``/t2`` only uses
        the config-driven limit.
        """
        app, limiter = self.build_app(
            {"X": "2 per second"}, default_limits=["1/second"]
        )

        def request_context_limit():
            # Choose the limit string from the first forwarded address,
            # falling back to the remote address and then to 127.0.0.1.
            known = {
                "127.0.0.1": "10 per minute",
                "127.0.0.2": "1 per minute",
            }
            forwarded = request.access_route and request.access_route[0]
            remote_addr = forwarded or request.remote_addr or '127.0.0.1'
            return known.setdefault(remote_addr, '1 per minute')

        @app.route("/t1")
        @limiter.limit("20/day")
        @limiter.limit(lambda: current_app.config.get("X"))
        @limiter.limit(request_context_limit)
        def t1():
            return "42"

        @app.route("/t2")
        @limiter.limit(lambda: current_app.config.get("X"))
        def t2():
            return "42"

        first = {"X_FORWARDED_FOR": "127.0.0.1, 127.0.0.0"}
        second = {"X_FORWARDED_FOR": "127.0.0.2"}

        with app.test_client() as cli:
            with hiro.Timeline().freeze() as timeline:

                def status(path, headers=None):
                    return cli.get(path, headers=headers).status_code

                for _ in range(10):
                    self.assertEqual(status("/t1", first), 200)
                    timeline.forward(1)
                self.assertEqual(status("/t1", first), 429)
                self.assertEqual(status("/t1", second), 200)
                self.assertEqual(status("/t1", second), 429)
                timeline.forward(60)
                self.assertEqual(status("/t1", second), 200)
                self.assertEqual(status("/t2"), 200)
                self.assertEqual(status("/t2"), 200)
                self.assertEqual(status("/t2"), 429)
                timeline.forward(1)
                self.assertEqual(status("/t2"), 200)
385
386 View Code Duplication
    def test_invalid_decorated_dynamic_limits(self):
        """An unparsable dynamic limit must be logged, not enforced.

        The config-supplied "2 per sec" is the invalid value. The route is
        still expected to be rejected on its second request, and the limiter
        logger must receive two "failed to load ratelimit" records followed
        by one "exceeded at endpoint" record.
        """
        app = Flask(__name__)
        app.config.setdefault("X", "2 per sec")
        limiter = Limiter(
            app, default_limits=["1/second"], key_func=get_remote_address
        )
        # A mock standing in for a logging handler; only ``level`` is set.
        mock_handler = mock.Mock(level=logging.INFO)
        limiter.logger.addHandler(mock_handler)

        @app.route("/t1")
        @limiter.limit(lambda: current_app.config.get("X"))
        def t1():
            return "42"

        with app.test_client() as client, hiro.Timeline().freeze():
            self.assertEqual(client.get("/t1").status_code, 200)
            self.assertEqual(client.get("/t1").status_code, 429)

        # 2 for invalid limit, 1 for warning.
        self.assertEqual(mock_handler.handle.call_count, 3)
        messages = [
            call[0][0].msg for call in mock_handler.handle.call_args_list
        ]
        self.assertIn("failed to load ratelimit", messages[0])
        self.assertIn("failed to load ratelimit", messages[1])
        self.assertIn("exceeded at endpoint", messages[2])
419
420 View Code Duplication
    def test_invalid_decorated_static_limits(self):
        """An unparsable static limit must be logged, not enforced.

        "2/sec" is the invalid value. The route is still expected to be
        rejected on its second request, and the first two records handed to
        the logger's handler must mention "failed to configure" and
        "exceeded at endpoint" respectively.
        """
        app = Flask(__name__)
        limiter = Limiter(
            app, default_limits=["1/second"], key_func=get_remote_address
        )
        # A mock standing in for a logging handler; only ``level`` is set.
        mock_handler = mock.Mock(level=logging.INFO)
        limiter.logger.addHandler(mock_handler)

        @app.route("/t1")
        @limiter.limit("2/sec")
        def t1():
            return "42"

        with app.test_client() as client, hiro.Timeline().freeze():
            self.assertEqual(client.get("/t1").status_code, 200)
            self.assertEqual(client.get("/t1").status_code, 429)

        messages = [
            call[0][0].msg for call in mock_handler.handle.call_args_list
        ]
        self.assertIn("failed to configure", messages[0])
        self.assertIn("exceeded at endpoint", messages[1])
446
447
    def test_named_shared_limit(self):
        """Routes sharing a scope name must share one counter.

        ``/t1`` and ``/t2`` use scope 'a' and ``/t3`` uses scope 'b'. After
        one hit on ``/t1``, ``/t3`` is still expected to pass while ``/t2``
        is rejected.
        """
        app, limiter = self.build_app()
        scope_a = limiter.shared_limit("1/minute", scope='a')
        scope_b = limiter.shared_limit("1/minute", scope='b')

        @app.route("/t1")
        @scope_a
        def route1():
            return "route1"

        @app.route("/t2")
        @scope_a
        def route2():
            return "route2"

        @app.route("/t3")
        @scope_b
        def route3():
            return "route3"

        with hiro.Timeline().freeze(), app.test_client() as client:
            for expected, path in ((200, "/t1"), (200, "/t3"), (429, "/t2")):
                self.assertEqual(expected, client.get(path).status_code)
472
473
    def test_dynamic_shared_limit(self):
        """Shared limits may take their scope from a callable.

        Two mocks return the scope names. The test checks both the
        resulting statuses and that each mock was called twice with the
        endpoint names asserted at the end.
        """
        app, limiter = self.build_app()
        fn_a = mock.Mock(return_value="foo")
        fn_b = mock.Mock(return_value="bar")

        limit_foo = limiter.shared_limit("1/minute", scope=fn_a)
        limit_bar = limiter.shared_limit("1/minute", scope=fn_b)

        @app.route("/t1")
        @limit_foo
        def t1():
            return "route1"

        @app.route("/t2")
        @limit_foo
        def t2():
            return "route2"

        @app.route("/t3")
        @limit_bar
        def t3():
            return "route3"

        with hiro.Timeline().freeze(), app.test_client() as client:
            sequence = (
                (200, "/t1"),
                (200, "/t3"),
                (429, "/t2"),
                (429, "/t3"),
            )
            for expected, path in sequence:
                self.assertEqual(expected, client.get(path).status_code)
            self.assertEqual(2, fn_a.call_count)
            self.assertEqual(2, fn_b.call_count)
            fn_b.assert_called_with("t3")
            fn_a.assert_has_calls([mock.call("t1"), mock.call("t2")])
508
509 View Code Duplication
    def test_conditional_limits(self):
        """Check that ``exempt_when`` switches a limit on and off.

        One route is always limited, one is always exempt, and one reads a
        local flag that the test flips between requests.
        """
        app = Flask(__name__)
        limiter = Limiter(app, key_func=get_remote_address)

        @app.route("/limited")
        @limiter.limit("1 per day")
        def limited_route():
            return "passed"

        @app.route("/unlimited")
        @limiter.limit("1 per day", exempt_when=lambda: True)
        def never_limited_route():
            return "should always pass"

        # Read by the lambda below each time it is called, so rebinding
        # this local later changes the outcome of subsequent requests.
        is_exempt = False

        @app.route("/conditional")
        @limiter.limit("1 per day", exempt_when=lambda: is_exempt)
        def conditionally_limited_route():
            return "conditional"

        with app.test_client() as client:

            def status(path):
                return client.get(path).status_code

            self.assertEqual(status("/limited"), 200)
            self.assertEqual(status("/limited"), 429)

            for _ in range(2):
                self.assertEqual(status("/unlimited"), 200)

            self.assertEqual(status("/conditional"), 200)
            self.assertEqual(status("/conditional"), 429)
            is_exempt = True
            self.assertEqual(status("/conditional"), 200)
            is_exempt = False
            self.assertEqual(status("/conditional"), 429)
544
545 View Code Duplication
    def test_conditional_shared_limits(self):
        """Check ``exempt_when`` on limits that share one scope.

        All three routes use the scope "test_scope". The always-exempt
        route is requested first and is expected to leave the shared
        allowance untouched; the conditional route follows a local flag.
        """
        app = Flask(__name__)
        limiter = Limiter(app, key_func=get_remote_address)

        @app.route("/limited")
        @limiter.shared_limit("1 per day", "test_scope")
        def limited_route():
            return "passed"

        @app.route("/unlimited")
        @limiter.shared_limit(
            "1 per day", "test_scope", exempt_when=lambda: True
        )
        def never_limited_route():
            return "should always pass"

        # Read by the lambda below each time it is called, so rebinding
        # this local later changes the outcome of subsequent requests.
        is_exempt = False

        @app.route("/conditional")
        @limiter.shared_limit(
            "1 per day", "test_scope", exempt_when=lambda: is_exempt
        )
        def conditionally_limited_route():
            return "conditional"

        with app.test_client() as client:

            def status(path):
                return client.get(path).status_code

            for _ in range(2):
                self.assertEqual(status("/unlimited"), 200)

            self.assertEqual(status("/limited"), 200)
            self.assertEqual(status("/limited"), 429)

            self.assertEqual(status("/conditional"), 429)
            is_exempt = True
            self.assertEqual(status("/conditional"), 200)
            is_exempt = False
            self.assertEqual(status("/conditional"), 429)
583
584
    def test_whitelisting(self):
        """Requests matched by a ``request_filter`` must not be limited.

        The filter accepts requests carrying the header ``internal: true``;
        ten such requests in a row are expected to pass even though the
        default limit is "1/minute".
        """
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["1/minute"],
            headers_enabled=True,
            key_func=get_remote_address,
        )

        @app.route("/")
        def t():
            return "test"

        @limiter.request_filter
        def w():
            return request.headers.get("internal", None) == "true"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as client:
                self.assertEqual(client.get("/").status_code, 200)
                self.assertEqual(client.get("/").status_code, 429)
                timeline.forward(60)
                self.assertEqual(client.get("/").status_code, 200)

                for _ in range(10):
                    response = client.get("/", headers={"internal": "true"})
                    self.assertEqual(response.status_code, 200)
617
618 View Code Duplication
    def test_separate_method_limits(self):
        """With ``per_method=True`` GET and POST must be counted separately.

        Each method is expected to pass once and then be rejected on the
        same route.
        """
        app, limiter = self.build_app()

        @limiter.limit("1/second", per_method=True)
        @app.route("/", methods=["GET", "POST"])
        def root():
            return "root"

        # The clock is frozen so that each pair of requests cannot straddle
        # a "1/second" boundary, matching the other per-second tests here.
        with hiro.Timeline().freeze(), app.test_client() as cli:
            self.assertEqual(200, cli.get("/").status_code)
            self.assertEqual(429, cli.get("/").status_code)
            self.assertEqual(200, cli.post("/").status_code)
            self.assertEqual(429, cli.post("/").status_code)
632
633 View Code Duplication
    def test_explicit_method_limits(self):
        """A limit restricted to ``methods=["GET"]`` must leave POST alone.

        GET is expected to pass once and then be rejected, while POST on the
        same route keeps returning 200.
        """
        app, limiter = self.build_app()

        @limiter.limit("1/second", methods=["GET"])
        @app.route("/", methods=["GET", "POST"])
        def root():
            return "root"

        # The clock is frozen so that the two GET requests cannot straddle
        # a "1/second" boundary, matching the other per-second tests here.
        with hiro.Timeline().freeze(), app.test_client() as cli:
            self.assertEqual(200, cli.get("/").status_code)
            self.assertEqual(429, cli.get("/").status_code)
            self.assertEqual(200, cli.post("/").status_code)
            self.assertEqual(200, cli.post("/").status_code)
647
648
    def test_decorated_limit_immediate(self):
        """A dynamic limit may read a value set by an outer decorator.

        ``append_info`` stores "2/minute" on ``g`` before calling the
        wrapped view, and the limit callable reads it back; two requests are
        expected to pass and the third to be rejected.
        """
        app, limiter = self.build_app(default_limits=["1/minute"])

        def append_info(fn):
            @wraps(fn)
            def with_rate_limit(*args, **kwargs):
                g.rate_limit = "2/minute"
                return fn(*args, **kwargs)

            return with_rate_limit

        @app.route("/", methods=["GET", "POST"])
        @append_info
        @limiter.limit(lambda: g.rate_limit, per_method=True)
        def root():
            return "root"

        with hiro.Timeline().freeze(), app.test_client() as client:
            for expected in (200, 200, 429):
                self.assertEqual(expected, client.get("/").status_code)
669
670
    def test_decorated_shared_limit_immediate(self):
        """A shared dynamic limit may read a value set by an outer decorator.

        ``/other`` only has the "1/minute" default, while ``/`` is wrapped
        so that "2/minute" is stored on ``g`` before the shared limit's
        callable reads it.
        """
        app, limiter = self.build_app(default_limits=['1/minute'])
        shared = limiter.shared_limit(lambda: g.rate_limit, 'shared')

        def append_info(fn):
            @wraps(fn)
            def with_rate_limit(*args, **kwargs):
                g.rate_limit = "2/minute"
                return fn(*args, **kwargs)

            return with_rate_limit

        @app.route("/", methods=["GET", "POST"])
        @append_info
        @shared
        def root():
            return "root"

        @app.route("/other", methods=["GET", "POST"])
        def other():
            return "other"

        with hiro.Timeline().freeze(), app.test_client() as client:
            sequence = (
                (200, "/other"),
                (429, "/other"),
                (200, "/"),
                (200, "/"),
                (429, "/"),
            )
            for expected, path in sequence:
                self.assertEqual(expected, client.get(path).status_code)
698
699
    def test_backward_compatibility_with_incorrect_ordering(self):
        """Limits applied above ``@app.route`` must still be enforced.

        Covers a plain view, a view wrapped by another decorator and a view
        with two stacked limits; each is expected to pass once and then be
        rejected.
        """
        app, limiter = self.build_app()

        def something_else(fn):
            # A pass-through decorator that only forwards the call.
            @functools.wraps(fn)
            def passthrough(*args, **kwargs):
                return fn(*args, **kwargs)

            return passthrough

        @limiter.limit("1/second")
        @app.route("/t1", methods=["GET", "POST"])
        def root():
            return "t1"

        @limiter.limit("1/second")
        @app.route("/t2", methods=["GET", "POST"])
        @something_else
        def t2():
            return "t2"

        @limiter.limit("2/second")
        @limiter.limit("1/second")
        @app.route("/t3", methods=["GET", "POST"])
        def t3():
            return "t3"

        with hiro.Timeline().freeze(), app.test_client() as client:
            for path in ("/t1", "/t2", "/t3"):
                self.assertEqual(200, client.get(path).status_code)
                self.assertEqual(429, client.get(path).status_code)
734
735
736
class BlueprintTests(FlaskLimiterTestCase):
    """Rate limiting of views that are registered through blueprints."""

    def test_blueprint(self):
        """Default and decorated limits both apply to blueprint routes."""
        app, limiter = self.build_app(default_limits=["1/minute"])
        bp = Blueprint("main", __name__)

        @bp.route("/t1")
        def t1():
            return "test"

        @bp.route("/t2")
        @limiter.limit("10 per minute")
        def t2():
            return "test"

        app.register_blueprint(bp)

        with app.test_client() as cli:
            # /t1 only has the 1/minute default.
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
            # /t2 carries its own 10-per-minute limit.
            for _ in range(10):
                self.assertEqual(cli.get("/t2").status_code, 200)
            self.assertEqual(cli.get("/t2").status_code, 429)

    def test_register_blueprint(self):
        """Limits and exemptions can be attached to whole blueprints.

        ``bp_1`` gets a static "1/second" limit, ``bp_2`` is left on the
        application default, ``bp_3`` is exempted and ``bp_4`` gets a limit
        supplied by a callable.
        """
        app, limiter = self.build_app(default_limits=["1/minute"])
        bp_1 = Blueprint("bp1", __name__)
        bp_2 = Blueprint("bp2", __name__)
        bp_3 = Blueprint("bp3", __name__)
        bp_4 = Blueprint("bp4", __name__)

        @bp_1.route("/t1")
        def t1():
            return "test"

        @bp_1.route("/t2")
        def t2():
            return "test"

        @bp_2.route("/t3")
        def t3():
            return "test"

        @bp_3.route("/t4")
        def t4():
            return "test"

        # Named after its route; previously this redefined ``t4`` above.
        @bp_4.route("/t5")
        def t5():
            return "test"

        def dy_limit():
            return "1/second"

        app.register_blueprint(bp_1)
        app.register_blueprint(bp_2)
        app.register_blueprint(bp_3)
        app.register_blueprint(bp_4)

        limiter.limit("1/second")(bp_1)
        limiter.exempt(bp_3)
        limiter.limit(dy_limit)(bp_4)

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                # bp_1: one hit per second on each of its routes.
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/t2").status_code, 200)

                # bp_2: still rejected after ten one-second steps.
                self.assertEqual(cli.get("/t3").status_code, 200)
                for _ in range(10):
                    timeline.forward(1)
                    self.assertEqual(cli.get("/t3").status_code, 429)

                # bp_3: exempt, so repeated hits keep succeeding.
                for _ in range(10):
                    self.assertEqual(cli.get("/t4").status_code, 200)

                # bp_4: limit comes from ``dy_limit``.
                self.assertEqual(cli.get("/t5").status_code, 200)
                self.assertEqual(cli.get("/t5").status_code, 429)

    def test_invalid_decorated_static_limit_blueprint(self):
        """A bad static blueprint limit is logged; the default still applies.

        The test expects the "2/sec" blueprint limit to produce a
        "failed to configure" log record, followed by an
        "exceeded at endpoint" record once the 1/second default trips.
        """
        app = Flask(__name__)
        limiter = Limiter(
            app, default_limits=["1/second"], key_func=get_remote_address
        )
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)
        bp = Blueprint("bp1", __name__)

        @bp.route("/t1")
        def t1():
            return "42"

        limiter.limit("2/sec")(bp)
        app.register_blueprint(bp)

        with app.test_client() as cli:
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)

        # Each entry is a ``call``; [0] is the positional args tuple and its
        # first element is the log record handed to ``handle``.
        calls = mock_handler.handle.call_args_list
        self.assertIn("failed to configure", calls[0][0][0].msg)
        self.assertIn("exceeded at endpoint", calls[1][0][0].msg)

    def test_invalid_decorated_dynamic_limits_blueprint(self):
        """A bad callable blueprint limit is logged on each request.

        The limit string is read from ``app.config["X"]`` ("2 per sec").
        Exactly three records are expected: two "failed to load ratelimit"
        messages and one "exceeded at endpoint".
        """
        app = Flask(__name__)
        app.config.setdefault("X", "2 per sec")
        limiter = Limiter(
            app, default_limits=["1/second"], key_func=get_remote_address
        )
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)
        bp = Blueprint("bp1", __name__)

        @bp.route("/t1")
        def t1():
            return "42"

        limiter.limit(lambda: current_app.config.get("X"))(bp)
        app.register_blueprint(bp)

        with app.test_client() as cli:
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)

        calls = mock_handler.handle.call_args_list
        self.assertEqual(mock_handler.handle.call_count, 3)
        self.assertIn("failed to load ratelimit", calls[0][0][0].msg)
        self.assertIn("failed to load ratelimit", calls[1][0][0].msg)
        self.assertIn("exceeded at endpoint", calls[2][0][0].msg)
885
886
class ViewsTests(FlaskLimiterTestCase):
    """Rate limits attached through the ``decorators`` attribute of views."""

    def test_pluggable_views(self):
        """``flask.views.View`` subclasses honour decorator limits."""
        app, limiter = self.build_app(default_limits=["1/hour"])

        class Va(View):
            methods = ['GET', 'POST']
            decorators = [limiter.limit("2/second")]

            def dispatch_request(self):
                return request.method.lower()

        class Vb(View):
            methods = ['GET']
            decorators = [limiter.limit("1/second, 3/minute")]

            def dispatch_request(self):
                return request.method.lower()

        class Vc(View):
            methods = ['GET']

            def dispatch_request(self):
                return request.method.lower()

        for rule, view in (("/a", Va.as_view("a")),
                           ("/b", Vb.as_view("b")),
                           ("/c", Vc.as_view("c"))):
            app.add_url_rule(rule, view_func=view)

        with hiro.Timeline().freeze() as timeline, \
                app.test_client() as cli:
            # /a: two hits per second, shared by GET and POST.
            for _ in range(2):
                self.assertEqual(200, cli.get("/a").status_code)
            self.assertEqual(429, cli.post("/a").status_code)

            # /b: one hit per second, capped at three per minute.
            for _ in range(3):
                self.assertEqual(200, cli.get("/b").status_code)
                timeline.forward(1)
            self.assertEqual(429, cli.get("/b").status_code)

            # /c: undecorated, so only the 1/hour default applies.
            self.assertEqual(200, cli.get("/c").status_code)
            self.assertEqual(429, cli.get("/c").status_code)

    def test_pluggable_method_views(self):
        """``MethodView`` subclasses honour decorator limits."""
        app, limiter = self.build_app(default_limits=["1/hour"])

        class Va(MethodView):
            decorators = [limiter.limit("2/second")]

            def get(self):
                return request.method.lower()

            def post(self):
                return request.method.lower()

        class Vb(MethodView):
            decorators = [limiter.limit("1/second, 3/minute")]

            def get(self):
                return request.method.lower()

        class Vc(MethodView):
            def get(self):
                return request.method.lower()

        class Vd(MethodView):
            decorators = [limiter.limit("1/minute", methods=['get'])]

            def get(self):
                return request.method.lower()

            def post(self):
                return request.method.lower()

        for rule, view in (("/a", Va.as_view("a")),
                           ("/b", Vb.as_view("b")),
                           ("/c", Vc.as_view("c")),
                           ("/d", Vd.as_view("d"))):
            app.add_url_rule(rule, view_func=view)

        with hiro.Timeline().freeze() as timeline, \
                app.test_client() as cli:
            # /a: two hits per second across both HTTP methods.
            for _ in range(2):
                self.assertEqual(200, cli.get("/a").status_code)
            self.assertEqual(429, cli.get("/a").status_code)
            self.assertEqual(429, cli.post("/a").status_code)

            # /b: one hit per second, capped at three per minute.
            for _ in range(3):
                self.assertEqual(200, cli.get("/b").status_code)
                timeline.forward(1)
            self.assertEqual(429, cli.get("/b").status_code)

            # /c: only the 1/hour default applies.
            self.assertEqual(200, cli.get("/c").status_code)
            self.assertEqual(429, cli.get("/c").status_code)

            # /d: the limit is declared for GET only; POSTs keep succeeding.
            self.assertEqual(200, cli.get("/d").status_code)
            self.assertEqual(429, cli.get("/d").status_code)
            for _ in range(2):
                self.assertEqual(200, cli.post("/d").status_code)

    def test_flask_restful_resource(self):
        """flask-restful ``Resource`` classes honour decorator limits."""
        app, limiter = self.build_app(default_limits=["1/hour"])
        api = RestfulApi(app)

        class Va(Resource):
            decorators = [limiter.limit("2/second")]

            def get(self):
                return request.method.lower()

            def post(self):
                return request.method.lower()

        class Vb(Resource):
            decorators = [limiter.limit("1/second, 3/minute")]

            def get(self):
                return request.method.lower()

        class Vc(Resource):
            def get(self):
                return request.method.lower()

        for resource, url in ((Va, "/a"), (Vb, "/b"), (Vc, "/c")):
            api.add_resource(resource, url)

        with hiro.Timeline().freeze() as timeline, \
                app.test_client() as cli:
            # /a: two hits per second across both HTTP methods.
            for _ in range(2):
                self.assertEqual(200, cli.get("/a").status_code)
            self.assertEqual(429, cli.get("/a").status_code)
            self.assertEqual(429, cli.post("/a").status_code)

            # /b: one hit per second, capped at three per minute.
            for _ in range(3):
                self.assertEqual(200, cli.get("/b").status_code)
                timeline.forward(1)
            self.assertEqual(429, cli.get("/b").status_code)

            # /c: only the 1/hour default applies.
            self.assertEqual(200, cli.get("/c").status_code)
            self.assertEqual(429, cli.get("/c").status_code)
1026
1027
1028
class FlaskExtTests(FlaskLimiterTestCase):
1029
    def test_reset(self):
        """``limiter.reset()`` clears recorded hits so the limit starts over.

        Uses ``assertIn`` so a failure shows the actual response body.
        """
        app, limiter = self.build_app({C.GLOBAL_LIMITS: "1 per day"})

        @app.route("/")
        def null():
            return "Hello Reset"

        with app.test_client() as cli:
            cli.get("/")  # uses up the single allowed hit
            self.assertIn("1 per 1 day", cli.get("/").data.decode())
            limiter.reset()
            # After the reset one request succeeds, then the limit bites again.
            self.assertEqual("Hello Reset", cli.get("/").data.decode())
            self.assertIn("1 per 1 day", cli.get("/").data.decode())
1042
1043 View Code Duplication
    def test_combined_rate_limits(self):
        """A decorated route uses its own limits, others use the globals.

        ``/t1`` declares "100 per hour;10/minute"; ``/t2`` falls back to the
        configured "1 per hour; 10 per day" and is rejected on its second hit.
        """
        app, limiter = self.build_app({
            C.GLOBAL_LIMITS: "1 per hour; 10 per day"
        })

        @app.route("/t1")
        @limiter.limit("100 per hour;10/minute")
        def t1():
            return "t1"

        @app.route("/t2")
        def t2():
            return "t2"

        # The timeline is only frozen, never advanced, so it is not bound.
        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/t1").status_code)
                self.assertEqual(200, cli.get("/t2").status_code)
                self.assertEqual(429, cli.get("/t2").status_code)
1062
1063
    def test_key_func(self):
        """A per-route key function overrides the limiter's default key.

        The route's key function always returns "test", so 100 requests sent
        with a forwarded-for header and one sent without it count against
        the same "100 per minute" limit.
        """
        app, limiter = self.build_app()

        @app.route("/t1")
        @limiter.limit("100 per minute", lambda: "test")
        def t1():
            return "test"

        # Loop-invariant, so built once instead of on every request.
        forwarded = {"X_FORWARDED_FOR": "127.0.0.2"}

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                for _ in range(100):
                    self.assertEqual(
                        200, cli.get("/t1", headers=forwarded).status_code
                    )
                self.assertEqual(429, cli.get("/t1").status_code)
1083
1084 View Code Duplication
    def test_logging(self):
        """Breaching a limit sends exactly one record to the limiter logger."""
        app = Flask(__name__)
        limiter = Limiter(app, key_func=get_remote_address)

        # A mock stands in for a logging handler; ``level`` is set to INFO.
        recorder = mock.Mock()
        recorder.level = logging.INFO
        limiter.logger.addHandler(recorder)

        @app.route("/t1")
        @limiter.limit("1/minute")
        def t1():
            return "test"

        with app.test_client() as cli:
            for expected in (200, 429):
                self.assertEqual(expected, cli.get("/t1").status_code)

        self.assertEqual(recorder.handle.call_count, 1)
1100
1101 View Code Duplication
    def test_reuse_logging(self):
        """Handlers copied from ``app.logger`` receive the breach record."""
        app = Flask(__name__)

        app_handler = mock.Mock()
        app_handler.level = logging.INFO
        app.logger.addHandler(app_handler)

        limiter = Limiter(app, key_func=get_remote_address)
        # Share every application handler with the limiter's own logger.
        for shared_handler in app.logger.handlers:
            limiter.logger.addHandler(shared_handler)

        @app.route("/t1")
        @limiter.limit("1/minute")
        def t1():
            return "42"

        with app.test_client() as cli:
            # Two requests against a 1/minute limit.
            for _ in range(2):
                cli.get("/t1")

        self.assertEqual(app_handler.handle.call_count, 1)
1120
1121 View Code Duplication
    def test_disabled_flag(self):
        """With ``C.ENABLED`` set to False no request is ever rejected."""
        app, limiter = self.build_app(
            config={C.ENABLED: False}, default_limits=["1/minute"]
        )

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.limit("10 per minute")
        def t2():
            return "test"

        with app.test_client() as cli:
            # One request more than each limit would normally allow.
            for _ in range(2):
                self.assertEqual(cli.get("/t1").status_code, 200)
            for _ in range(11):
                self.assertEqual(cli.get("/t2").status_code, 200)
1141
1142
1143
    def test_multiple_apps(self):
        """One ``Limiter`` initialised on two apps limits each app's routes."""
        limiter = Limiter(
            default_limits=["1/second"], key_func=get_remote_address
        )
        app1 = Flask(__name__)
        app2 = Flask(__name__)
        limiter.init_app(app1)
        limiter.init_app(app2)

        @app1.route("/ping")
        def ping():
            return "PONG"

        @app1.route("/slowping")
        @limiter.limit("1/minute")
        def slow_ping():
            return "PONG"

        @app2.route("/ping")
        @limiter.limit("2/second")
        def ping_2():
            return "PONG"

        @app2.route("/slowping")
        @limiter.limit("2/minute")
        def slow_ping_2():
            return "PONG"

        def status(client, path):
            # Issue a GET and report only the HTTP status code.
            return client.get(path).status_code

        with hiro.Timeline().freeze() as timeline:
            with app1.test_client() as cli:
                # app1: /ping on the 1/second default, /slowping at 1/minute.
                self.assertEqual(status(cli, "/ping"), 200)
                self.assertEqual(status(cli, "/ping"), 429)
                timeline.forward(1)
                self.assertEqual(status(cli, "/ping"), 200)
                self.assertEqual(status(cli, "/slowping"), 200)
                timeline.forward(59)
                self.assertEqual(status(cli, "/slowping"), 429)
                timeline.forward(1)
                self.assertEqual(status(cli, "/slowping"), 200)
            with app2.test_client() as cli:
                # app2: /ping at 2/second, /slowping at 2/minute.
                self.assertEqual(status(cli, "/ping"), 200)
                self.assertEqual(status(cli, "/ping"), 200)
                self.assertEqual(status(cli, "/ping"), 429)
                timeline.forward(1)
                self.assertEqual(status(cli, "/ping"), 200)
                self.assertEqual(status(cli, "/slowping"), 200)
                timeline.forward(59)
                self.assertEqual(status(cli, "/slowping"), 200)
                self.assertEqual(status(cli, "/slowping"), 429)
                timeline.forward(1)
                self.assertEqual(status(cli, "/slowping"), 200)
1195
1196
    def test_headers_no_breach(self):
        """Rate-limit headers are reported for requests within the limit."""
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.limit("2/second; 5 per minute; 10/hour")
        def t2():
            return "test"

        # (path, limit, remaining, reset offset in seconds, retry-after)
        expectations = (
            ("/t1", '10', '9', 61, 60),
            ("/t2", '2', '1', 2, 1),
        )

        with hiro.Timeline().freeze(), app.test_client() as cli:
            for path, limit, remaining, reset_in, retry_after in expectations:
                headers = cli.get(path).headers
                self.assertEqual(headers.get('X-RateLimit-Limit'), limit)
                self.assertEqual(
                    headers.get('X-RateLimit-Remaining'), remaining
                )
                # Time is frozen, so ``time.time()`` matches the request time.
                self.assertEqual(
                    headers.get('X-RateLimit-Reset'),
                    str(int(time.time() + reset_in))
                )
                self.assertEqual(
                    headers.get('Retry-After'), str(retry_after)
                )
1237
1238 View Code Duplication
    def test_headers_breach(self):
        """Headers on the eleventh request describe the 10-per-minute limit."""
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline, \
                app.test_client() as cli:
            # Eleven requests one second apart; only the last is inspected.
            for _ in range(11):
                resp = cli.get("/t1")
                timeline.forward(1)

            headers = resp.headers
            self.assertEqual(headers.get('X-RateLimit-Limit'), '10')
            self.assertEqual(headers.get('X-RateLimit-Remaining'), '0')
            self.assertEqual(
                headers.get('X-RateLimit-Reset'),
                str(int(time.time() + 50))
            )
            self.assertEqual(headers.get('Retry-After'), str(50))
1267
1268
    def test_retry_after(self):
        """Waiting ``Retry-After`` seconds makes the next request succeed."""
        app = Flask(__name__)
        # The limiter registers itself on the app; the instance is unused.
        _ = Limiter(
            app,
            default_limits=["1/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/t1")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline, \
                app.test_client() as cli:
            first = cli.get("/t1")
            wait_seconds = int(first.headers.get('Retry-After'))
            self.assertTrue(wait_seconds > 0)

            timeline.forward(wait_seconds)
            second = cli.get("/t1")
            self.assertEqual(second.status_code, 200)
1289
1290
    def test_retry_after_exists_seconds(self):
        """A large ``Retry-After`` in seconds set by the view stays large."""
        app = Flask(__name__)
        # The limiter registers itself on the app; the instance is unused.
        _ = Limiter(
            app,
            default_limits=["1/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/t1")
        def t():
            return "", 200, {'Retry-After': '1000000'}

        with app.test_client() as cli:
            response = cli.get("/t1")
            seconds = int(response.headers.get('Retry-After'))
            self.assertTrue(seconds > 1000)
1308
1309
    def test_retry_after_exists_rfc1123(self):
        """A view-supplied HTTP-date ``Retry-After`` is returned as seconds.

        The response header must parse as an integer greater than 1000.
        """
        app = Flask(__name__)
        # The limiter registers itself on the app; the instance is unused.
        _ = Limiter(
            app,
            default_limits=["1/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/t1")
        def t():
            return "", 200, {'Retry-After': 'Sun, 06 Nov 2032 01:01:01 GMT'}

        with app.test_client() as cli:
            response = cli.get("/t1")
            seconds = int(response.headers.get('Retry-After'))
            self.assertTrue(seconds > 1000)
1327
1328
    def test_custom_headers_from_setter(self):
        """Header names overridden on the limiter instance are emitted.

        The limiter is built with ``retry_after='http-date'`` and the test
        expects ``Retry-After`` to come back as an HTTP date.
        """
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address,
            retry_after='http-date'
        )
        custom_names = (
            (HEADERS.RESET, 'X-Reset'),
            (HEADERS.LIMIT, 'X-Limit'),
            (HEADERS.REMAINING, 'X-Remaining'),
        )
        for header, name in custom_names:
            limiter._header_mapping[header] = name

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        # Frozen at timestamp 0 so the expected HTTP date is a fixed string.
        with hiro.Timeline().freeze(0) as timeline, \
                app.test_client() as cli:
            # Eleven requests one second apart; only the last is inspected.
            for _ in range(11):
                resp = cli.get("/t1")
                timeline.forward(1)

            headers = resp.headers
            self.assertEqual(headers.get('X-Limit'), '10')
            self.assertEqual(headers.get('X-Remaining'), '0')
            self.assertEqual(
                headers.get('X-Reset'), str(int(time.time() + 50))
            )
            self.assertEqual(
                headers.get('Retry-After'),
                'Thu, 01 Jan 1970 00:01:01 GMT'
            )
1361
1362
    def test_custom_headers_from_config(self):
        """Header names set through app config are emitted."""
        app = Flask(__name__)
        configured_names = (
            (C.HEADER_LIMIT, "X-Limit"),
            (C.HEADER_REMAINING, "X-Remaining"),
            (C.HEADER_RESET, "X-Reset"),
        )
        for option, name in configured_names:
            app.config.setdefault(option, name)

        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze() as timeline, \
                app.test_client() as cli:
            # Eleven requests one second apart; only the last is inspected.
            for _ in range(11):
                resp = cli.get("/t1")
                timeline.forward(1)

            headers = resp.headers
            self.assertEqual(headers.get('X-Limit'), '10')
            self.assertEqual(headers.get('X-Remaining'), '0')
            self.assertEqual(
                headers.get('X-Reset'), str(int(time.time() + 50))
            )
1390
1391 View Code Duplication
    def test_application_shared_limit(self):
        """An application limit is counted across all routes together."""
        app, limiter = self.build_app(application_limits=["2/minute"])

        @app.route("/t1")
        def t1():
            return "route1"

        @app.route("/t2")
        def t2():
            return "route2"

        # The third request is rejected although /t1 was hit only once before.
        sequence = (("/t1", 200), ("/t2", 200), ("/t1", 429))
        with hiro.Timeline().freeze(), app.test_client() as cli:
            for path, expected in sequence:
                self.assertEqual(expected, cli.get(path).status_code)
1407
1408 View Code Duplication
    def test_callable_default_limit(self):
        """A callable default limit is applied to each route separately."""
        app, limiter = self.build_app(default_limits=[lambda: "1/minute"])

        @app.route("/t1")
        def t1():
            return "t1"

        @app.route("/t2")
        def t2():
            return "t2"

        # Each route gets one successful hit before being rejected.
        sequence = (("/t1", 200), ("/t2", 200), ("/t1", 429), ("/t2", 429))
        with hiro.Timeline().freeze(), app.test_client() as cli:
            for path, expected in sequence:
                self.assertEqual(cli.get(path).status_code, expected)
1425
1426
    def test_callable_application_limit(self):
        """A callable application limit is counted across routes together."""
        app, limiter = self.build_app(application_limits=[lambda: "1/minute"])

        @app.route("/t1")
        def t1():
            return "t1"

        @app.route("/t2")
        def t2():
            return "t2"

        # The hit on /t1 uses up the budget, so /t2 is rejected straight away.
        with hiro.Timeline().freeze(), app.test_client() as cli:
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t2").status_code, 429)
1442
1443
    def test_no_auto_check(self):
        """With ``auto_check=False`` limits bite only via ``limiter.check()``."""
        app, limiter = self.build_app(auto_check=False)

        @limiter.limit("1/second", per_method=True)
        @app.route("/", methods=["GET", "POST"])
        def root():
            return "root"

        # No automatic check: both requests in the same second succeed.
        with hiro.Timeline().freeze(), app.test_client() as cli:
            for _ in range(2):
                self.assertEqual(200, cli.get("/").status_code)

        # attach before_request to perform check
        @app.before_request
        def _():
            limiter.check()

        # Explicit check in place: the second request is now rejected.
        with hiro.Timeline().freeze(), app.test_client() as cli:
            for expected in (200, 429):
                self.assertEqual(expected, cli.get("/").status_code)
1465
1466
1467
    def test_custom_key_prefix(self):
        """Apps sharing one redis store each get their own "1/day" budget.

        One app sets the key prefix through a constructor argument, one
        through config, and one sets none.
        """
        redis_uri = "redis://localhost:6379"
        app1, limiter1 = self.build_app(
            key_prefix="moo", storage_uri=redis_uri
        )
        app2, limiter2 = self.build_app(
            {C.KEY_PREFIX: "cow"}, storage_uri=redis_uri
        )
        app3, limiter3 = self.build_app(storage_uri=redis_uri)

        # NOTE(review): every view is deliberately still named ``t1`` so all
        # three apps expose the same endpoint name; presumably that leaves
        # the key prefix as the only difference - confirm before renaming.
        pairs = ((app1, limiter1), (app2, limiter2), (app3, limiter3))
        for app, limiter in pairs:
            @app.route("/test")
            @limiter.limit("1/day")
            def t1():
                return "app1 test"

        for app, _limiter in pairs:
            with app.test_client() as cli:
                resp = cli.get("/test")
                self.assertEqual(200, resp.status_code)
                resp = cli.get("/test")
                self.assertEqual(429, resp.status_code)
1507