Completed
Push — master ( 7140d7...45c1cc )
by Ali-Akber
32s
created

DecoratorTests.something_else()   A

Complexity

Conditions 2

Size

Total Lines 5

Duplication

Lines 5
Ratio 100 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 2
c 1
b 0
f 0
dl 5
loc 5
rs 9.4285
1
"""
2
3
"""
4
import json
5
import logging
6
import time
7
import unittest
8
from functools import wraps
9
10
import functools
11
import hiro
12
import mock
13
import redis
14
import datetime
15
from flask import Flask, Blueprint, request, current_app, make_response, g
16
from flask_restful import Resource, Api as RestfulApi
17
from flask.views import View, MethodView
18
from limits.errors import ConfigurationError
19
from limits.storage import MemcachedStorage
20
from limits.strategies import MovingWindowRateLimiter
21
22
from flask_limiter.extension import C, Limiter, HEADERS
23
from flask_limiter.util import get_remote_address, get_ipaddr
24
from tests import FlaskLimiterTestCase
25
26
27
class ConfigurationTests(FlaskLimiterTestCase):
    """Check how ``Limiter`` validates and resolves its configuration."""

    def test_invalid_strategy(self):
        """An unknown strategy name in app config raises ConfigurationError."""
        app = Flask(__name__)
        app.config.setdefault(C.STRATEGY, "fubar")
        self.assertRaises(
            ConfigurationError, Limiter, app, key_func=get_remote_address
        )

    def test_invalid_storage_string(self):
        """A storage URL with an unknown scheme raises ConfigurationError."""
        app = Flask(__name__)
        app.config.setdefault(C.STORAGE_URL, "fubar://localhost:1234")
        self.assertRaises(
            ConfigurationError, Limiter, app, key_func=get_remote_address
        )

    def test_constructor_arguments_over_config(self):
        """Constructor arguments win over the equivalent app config keys."""
        app = Flask(__name__)
        app.config.setdefault(C.STRATEGY, "fixed-window-elastic-expiry")
        limiter = Limiter(
            strategy='moving-window', key_func=get_remote_address
        )
        limiter.init_app(app)
        # The config now names a redis storage URL; the memcached
        # ``storage_uri`` passed to the second limiter below must still win.
        app.config.setdefault(C.STORAGE_URL, "redis://localhost:6379")
        self.assertEqual(type(limiter._limiter), MovingWindowRateLimiter)
        limiter = Limiter(
            storage_uri='memcached://localhost:11211',
            key_func=get_remote_address
        )
        limiter.init_app(app)
        self.assertEqual(type(limiter._storage), MemcachedStorage)
57
58
59
class ErrorHandlingTests(FlaskLimiterTestCase):
    """Cover 429 responses and the limiter's behaviour on storage failures."""

    def test_error_message(self):
        """The default 429 description names the limit that was exceeded."""
        app, limiter = self.build_app({C.GLOBAL_LIMITS: "1 per day"})

        @app.route("/")
        def null():
            return ""

        with app.test_client() as cli:
            cli.get("/")
            self.assertIn("1 per 1 day", cli.get("/").data.decode())

            # A custom 429 handler sees the same text in ``e.description``.
            @app.errorhandler(429)
            def ratelimit_handler(e):
                return make_response(
                    '{"error" : "rate limit %s"}' % str(e.description), 429
                )

            self.assertEqual(
                {'error': 'rate limit 1 per 1 day'},
                json.loads(cli.get("/").data.decode())
            )

    def test_custom_error_message(self):
        """``error_message`` (string or callable) becomes the 429 description."""
        app, limiter = self.build_app()

        @app.errorhandler(429)
        def ratelimit_handler(e):
            return make_response(e.description, 429)

        def l1():
            return "1/second"

        def e1():
            return "dos"

        @limiter.limit("1/second", error_message="uno")
        @app.route("/t1")
        def t1():
            return "1"

        @limiter.limit(l1, error_message=e1)
        @app.route("/t2")
        def t2():
            return "2"

        s1 = limiter.shared_limit(
            "1/second", scope='error_message', error_message="tres"
        )

        @app.route("/t3")
        @s1
        def t3():
            return "3"

        expected = (("/t1", b'uno'), ("/t2", b'dos'), ("/t3", b'tres'))
        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                for path, body in expected:
                    # First hit consumes the 1/second budget; second is 429.
                    cli.get(path)
                    resp = cli.get(path)
                    self.assertEqual(429, resp.status_code, path)
                    self.assertEqual(resp.data, body, path)

    def test_swallow_error(self):
        """With SWALLOW_ERRORS set, a failing limiter does not block requests."""
        app, limiter = self.build_app({
            C.GLOBAL_LIMITS: "1 per day",
            C.SWALLOW_ERRORS: True
        })

        @app.route("/")
        def null():
            return "ok"

        def raiser(*a, **k):
            raise Exception

        with app.test_client() as cli:
            with mock.patch(
                "limits.strategies.FixedWindowRateLimiter.hit"
            ) as hit:
                hit.side_effect = raiser
                self.assertIn("ok", cli.get("/").data.decode())

    def test_no_swallow_error(self):
        """Without SWALLOW_ERRORS, a limiter failure surfaces as a 500."""
        app, limiter = self.build_app({
            C.GLOBAL_LIMITS: "1 per day",
        })

        @app.route("/")
        def null():
            return "ok"

        @app.errorhandler(500)
        def e500(e):
            return str(e), 500

        def raiser(*a, **k):
            raise Exception("underlying")

        with app.test_client() as cli:
            with mock.patch(
                "limits.strategies.FixedWindowRateLimiter.hit"
            ) as hit:
                hit.side_effect = raiser
                self.assertEqual(500, cli.get("/").status_code)
                self.assertEqual("underlying", cli.get("/").data.decode())

    def test_fallback_to_memory_config(self):
        """The in-memory fallback can be given as an argument or via config."""
        _, limiter = self.build_app(
            config={C.ENABLED: True},
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
            in_memory_fallback=["1/minute"]
        )
        self.assertEqual(len(limiter._in_memory_fallback), 1)

        _, limiter = self.build_app(
            config={C.ENABLED: True,
                    C.IN_MEMORY_FALLBACK: "1/minute"},
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
        )
        self.assertEqual(len(limiter._in_memory_fallback), 1)

    def test_fallback_to_memory_backoff_check(self):
        """After redis recovers, the limiter returns to it only after a backoff."""
        app, limiter = self.build_app(
            config={C.ENABLED: True},
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
            in_memory_fallback=["1/minute"]
        )

        @app.route("/t1")
        def t1():
            return "test"

        # Accept keyword arguments too: redis commands may be issued with
        # options, and a TypeError here would mask the simulated outage.
        def raiser(*a, **k):
            raise Exception("redis dead")

        with app.test_client() as cli:
            with hiro.Timeline() as timeline:
                with mock.patch(
                    "redis.client.Redis.execute_command"
                ) as exec_command:
                    exec_command.side_effect = raiser
                    # Redis "down": the 1/minute in-memory fallback applies.
                    self.assertEqual(cli.get("/t1").status_code, 200)
                    self.assertEqual(cli.get("/t1").status_code, 429)
                    # 1+2+4+8+16 = 31s elapsed: still inside the minute.
                    for step in (1, 2, 4, 8, 16):
                        timeline.forward(step)
                        self.assertEqual(cli.get("/t1").status_code, 429)
                    # +32s = 63s total: the fallback window has reset.
                    timeline.forward(32)
                    self.assertEqual(cli.get("/t1").status_code, 200)
                # Redis is reachable again, but the very next request is still
                # rejected by the exhausted in-memory fallback; the switch
                # back to redis presumably waits for the backoff check.
                self.assertEqual(cli.get("/t1").status_code, 429)
                timeline.forward(2)
                # Back on redis: the 5/minute default limit applies.
                for _ in range(5):
                    self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)

    def test_fallback_to_memory(self):
        """Redis failures switch to the in-memory fallback, then back again."""
        app, limiter = self.build_app(
            config={C.ENABLED: True},
            default_limits=["5/minute"],
            storage_uri="redis://localhost:6379",
            in_memory_fallback=["1/minute"]
        )

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.limit("3 per minute")
        def t2():
            return "test"

        # Accept keyword arguments too (see test_fallback_to_memory_backoff_check).
        def raiser(*a, **k):
            raise Exception("redis dead")

        with app.test_client() as cli:
            # Redis healthy: the regular limits (5/minute, 3/minute) apply.
            for _ in range(5):
                self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
            for _ in range(3):
                self.assertEqual(cli.get("/t2").status_code, 200)
            self.assertEqual(cli.get("/t2").status_code, 429)

            with mock.patch(
                "redis.client.Redis.execute_command"
            ) as exec_command:
                exec_command.side_effect = raiser
                # Redis "down": every route gets the 1/minute fallback.
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
                self.assertEqual(cli.get("/t2").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 429)
            # redis back to normal, go back to regular limits
            with hiro.Timeline() as timeline:
                timeline.forward(2)
                limiter._storage.storage.flushall()
                for _ in range(3):
                    self.assertEqual(cli.get("/t2").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 429)
284
285
286
class DecoratorTests(FlaskLimiterTestCase):
    """Exercise the ``limit``, ``shared_limit`` and ``exempt`` decorators."""

    def test_multiple_decorators(self):
        """Stacked limits are all enforced, each with its own key function."""
        app, limiter = self.build_app(key_func=get_ipaddr)

        @app.route("/t1")
        @limiter.limit(
            "100 per minute", lambda: "test"
        )  # effectively becomes a limit for all users
        @limiter.limit("50/minute")  # per ip as per default key_func
        def t1():
            return "test"

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                for i in range(100):
                    self.assertEqual(
                        200 if i < 50 else 429,
                        cli.get(
                            "/t1", headers={"X_FORWARDED_FOR": "127.0.0.2"}
                        ).status_code
                    )
                for _ in range(50):
                    self.assertEqual(200, cli.get("/t1").status_code)
                self.assertEqual(429, cli.get("/t1").status_code)
                self.assertEqual(
                    429,
                    cli.get(
                        "/t1", headers={"X_FORWARDED_FOR": "127.0.0.3"}
                    ).status_code
                )

    def test_exempt_routes(self):
        """Routes marked ``@limiter.exempt`` bypass the default limits."""
        app, limiter = self.build_app(default_limits=["1/minute"])

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.exempt
        def t2():
            return "test"

        with app.test_client() as cli:
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
            self.assertEqual(cli.get("/t2").status_code, 200)
            self.assertEqual(cli.get("/t2").status_code, 200)

    def test_decorated_dynamic_limits(self):
        """Callable limits are evaluated per request (config / request data)."""
        app, limiter = self.build_app(
            {"X": "2 per second"}, default_limits=["1/second"]
        )

        def request_context_limit():
            # Per-client limit keyed on the first forwarded address.
            per_ip = {
                "127.0.0.1": "10 per minute",
                "127.0.0.2": "1 per minute"
            }
            remote_addr = (
                request.access_route and request.access_route[0]
            ) or request.remote_addr or '127.0.0.1'
            return per_ip.get(remote_addr, '1 per minute')

        @app.route("/t1")
        @limiter.limit("20/day")
        @limiter.limit(lambda: current_app.config.get("X"))
        @limiter.limit(request_context_limit)
        def t1():
            return "42"

        @app.route("/t2")
        @limiter.limit(lambda: current_app.config.get("X"))
        def t2():
            return "42"

        headers_ip1 = {"X_FORWARDED_FOR": "127.0.0.1, 127.0.0.0"}
        headers_ip2 = {"X_FORWARDED_FOR": "127.0.0.2"}

        with app.test_client() as cli:
            with hiro.Timeline().freeze() as timeline:
                for _ in range(10):
                    self.assertEqual(
                        cli.get("/t1", headers=headers_ip1).status_code, 200
                    )
                    timeline.forward(1)
                self.assertEqual(
                    cli.get("/t1", headers=headers_ip1).status_code, 429
                )
                self.assertEqual(
                    cli.get("/t1", headers=headers_ip2).status_code, 200
                )
                self.assertEqual(
                    cli.get("/t1", headers=headers_ip2).status_code, 429
                )
                timeline.forward(60)
                self.assertEqual(
                    cli.get("/t1", headers=headers_ip2).status_code, 200
                )
                self.assertEqual(cli.get("/t2").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/t2").status_code, 200)

    def test_invalid_decorated_dynamic_limits(self):
        """An unparsable dynamic limit is logged and the default applies."""
        app = Flask(__name__)
        app.config.setdefault("X", "2 per sec")
        limiter = Limiter(
            app, default_limits=["1/second"], key_func=get_remote_address
        )
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)

        @app.route("/t1")
        @limiter.limit(lambda: current_app.config.get("X"))
        def t1():
            return "42"

        with app.test_client() as cli:
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        # 2 for invalid limit, 1 for warning.
        self.assertEqual(mock_handler.handle.call_count, 3)
        # logging calls handler.handle(record); pull out each LogRecord.
        records = [c[0][0] for c in mock_handler.handle.call_args_list]
        self.assertIn("failed to load ratelimit", records[0].msg)
        self.assertIn("failed to load ratelimit", records[1].msg)
        self.assertIn("exceeded at endpoint", records[2].msg)

    def test_invalid_decorated_static_limits(self):
        """An unparsable static limit is logged and the default applies."""
        app = Flask(__name__)
        limiter = Limiter(
            app, default_limits=["1/second"], key_func=get_remote_address
        )
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)

        @app.route("/t1")
        @limiter.limit("2/sec")
        def t1():
            return "42"

        with app.test_client() as cli:
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        # logging calls handler.handle(record); pull out each LogRecord.
        records = [c[0][0] for c in mock_handler.handle.call_args_list]
        self.assertIn("failed to configure", records[0].msg)
        self.assertIn("exceeded at endpoint", records[1].msg)

    def test_named_shared_limit(self):
        """Routes sharing a named scope share one budget; other scopes don't."""
        app, limiter = self.build_app()
        shared_limit_a = limiter.shared_limit("1/minute", scope='a')
        shared_limit_b = limiter.shared_limit("1/minute", scope='b')

        @app.route("/t1")
        @shared_limit_a
        def route1():
            return "route1"

        @app.route("/t2")
        @shared_limit_a
        def route2():
            return "route2"

        @app.route("/t3")
        @shared_limit_b
        def route3():
            return "route3"

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/t1").status_code)
                self.assertEqual(200, cli.get("/t3").status_code)
                self.assertEqual(429, cli.get("/t2").status_code)

    def test_dynamic_shared_limit(self):
        """A callable scope is invoked with the endpoint name on each request."""
        app, limiter = self.build_app()
        fn_a = mock.Mock()
        fn_b = mock.Mock()
        fn_a.return_value = "foo"
        fn_b.return_value = "bar"

        dy_limit_a = limiter.shared_limit("1/minute", scope=fn_a)
        dy_limit_b = limiter.shared_limit("1/minute", scope=fn_b)

        @app.route("/t1")
        @dy_limit_a
        def t1():
            return "route1"

        @app.route("/t2")
        @dy_limit_a
        def t2():
            return "route2"

        @app.route("/t3")
        @dy_limit_b
        def t3():
            return "route3"

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/t1").status_code)
                self.assertEqual(200, cli.get("/t3").status_code)
                self.assertEqual(429, cli.get("/t2").status_code)
                self.assertEqual(429, cli.get("/t3").status_code)
                self.assertEqual(2, fn_a.call_count)
                self.assertEqual(2, fn_b.call_count)
                fn_b.assert_called_with("t3")
                fn_a.assert_has_calls([mock.call("t1"), mock.call("t2")])

    def test_conditional_limits(self):
        """Test that the conditional activation of the limits work."""
        app = Flask(__name__)
        limiter = Limiter(app, key_func=get_remote_address)

        @app.route("/limited")
        @limiter.limit("1 per day")
        def limited_route():
            return "passed"

        @app.route("/unlimited")
        @limiter.limit("1 per day", exempt_when=lambda: True)
        def never_limited_route():
            return "should always pass"

        # The exempt_when lambda below closes over this local, so rebinding
        # it later toggles exemption for subsequent requests.
        is_exempt = False

        @app.route("/conditional")
        @limiter.limit("1 per day", exempt_when=lambda: is_exempt)
        def conditionally_limited_route():
            return "conditional"

        with app.test_client() as cli:
            self.assertEqual(cli.get("/limited").status_code, 200)
            self.assertEqual(cli.get("/limited").status_code, 429)

            self.assertEqual(cli.get("/unlimited").status_code, 200)
            self.assertEqual(cli.get("/unlimited").status_code, 200)

            self.assertEqual(cli.get("/conditional").status_code, 200)
            self.assertEqual(cli.get("/conditional").status_code, 429)
            is_exempt = True
            self.assertEqual(cli.get("/conditional").status_code, 200)
            is_exempt = False
            self.assertEqual(cli.get("/conditional").status_code, 429)

    def test_conditional_shared_limits(self):
        """Test that conditional shared limits work."""
        app = Flask(__name__)
        limiter = Limiter(app, key_func=get_remote_address)

        @app.route("/limited")
        @limiter.shared_limit("1 per day", "test_scope")
        def limited_route():
            return "passed"

        @app.route("/unlimited")
        @limiter.shared_limit(
            "1 per day", "test_scope", exempt_when=lambda: True
        )
        def never_limited_route():
            return "should always pass"

        # Closed over by the exempt_when lambda below (see
        # test_conditional_limits): rebinding it toggles exemption.
        is_exempt = False

        @app.route("/conditional")
        @limiter.shared_limit(
            "1 per day", "test_scope", exempt_when=lambda: is_exempt
        )
        def conditionally_limited_route():
            return "conditional"

        with app.test_client() as cli:
            self.assertEqual(cli.get("/unlimited").status_code, 200)
            self.assertEqual(cli.get("/unlimited").status_code, 200)

            self.assertEqual(cli.get("/limited").status_code, 200)
            self.assertEqual(cli.get("/limited").status_code, 429)

            # Same scope as /limited, whose shared budget is already spent.
            self.assertEqual(cli.get("/conditional").status_code, 429)
            is_exempt = True
            self.assertEqual(cli.get("/conditional").status_code, 200)
            is_exempt = False
            self.assertEqual(cli.get("/conditional").status_code, 429)

    def test_whitelisting(self):
        """Requests for which a ``request_filter`` returns True bypass limits."""
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["1/minute"],
            headers_enabled=True,
            key_func=get_remote_address
        )

        @app.route("/")
        def t():
            return "test"

        @limiter.request_filter
        def w():
            return request.headers.get("internal", None) == "true"

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                self.assertEqual(cli.get("/").status_code, 200)
                self.assertEqual(cli.get("/").status_code, 429)
                timeline.forward(60)
                self.assertEqual(cli.get("/").status_code, 200)

                for _ in range(10):
                    self.assertEqual(
                        cli.get(
                            "/", headers={"internal": "true"}
                        ).status_code,
                        200
                    )

    def test_separate_method_limits(self):
        """``per_method=True`` keeps a separate budget for each HTTP method."""
        app, limiter = self.build_app()

        @limiter.limit("1/second", per_method=True)
        @app.route("/", methods=["GET", "POST"])
        def root():
            return "root"

        # Frozen so the 1/second window cannot roll over between requests.
        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/").status_code)
                self.assertEqual(429, cli.get("/").status_code)
                self.assertEqual(200, cli.post("/").status_code)
                self.assertEqual(429, cli.post("/").status_code)

    def test_explicit_method_limits(self):
        """``methods=[...]`` restricts the limit to the listed HTTP methods."""
        app, limiter = self.build_app()

        @limiter.limit("1/second", methods=["GET"])
        @app.route("/", methods=["GET", "POST"])
        def root():
            return "root"

        # Frozen so the 1/second window cannot roll over between requests.
        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/").status_code)
                self.assertEqual(429, cli.get("/").status_code)
                self.assertEqual(200, cli.post("/").status_code)
                self.assertEqual(200, cli.post("/").status_code)

    def test_decorated_limit_immediate(self):
        """A dynamic limit can read ``g`` state set by an outer decorator."""
        app, limiter = self.build_app(default_limits=["1/minute"])

        def append_info(fn):
            @wraps(fn)
            def __inner(*args, **kwargs):
                g.rate_limit = "2/minute"
                return fn(*args, **kwargs)
            return __inner

        @app.route("/", methods=["GET", "POST"])
        @append_info
        @limiter.limit(lambda: g.rate_limit, per_method=True)
        def root():
            return "root"

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/").status_code)
                self.assertEqual(200, cli.get("/").status_code)
                self.assertEqual(429, cli.get("/").status_code)

    def test_decorated_shared_limit_immediate(self):
        """A dynamic shared limit can read ``g`` state set by an outer decorator."""
        app, limiter = self.build_app(default_limits=['1/minute'])
        shared = limiter.shared_limit(lambda: g.rate_limit, 'shared')

        def append_info(fn):
            @wraps(fn)
            def __inner(*args, **kwargs):
                g.rate_limit = "2/minute"
                return fn(*args, **kwargs)
            return __inner

        @app.route("/", methods=["GET", "POST"])
        @append_info
        @shared
        def root():
            return "root"

        @app.route("/other", methods=["GET", "POST"])
        def other():
            return "other"

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/other").status_code)
                self.assertEqual(429, cli.get("/other").status_code)
                self.assertEqual(200, cli.get("/").status_code)
                self.assertEqual(200, cli.get("/").status_code)
                self.assertEqual(429, cli.get("/").status_code)

    def test_backward_compatibility_with_incorrect_ordering(self):
        """``@limiter.limit`` still applies when placed above ``@app.route``."""
        app, limiter = self.build_app()

        def something_else(fn):
            @functools.wraps(fn)
            def __inner(*args, **kwargs):
                return fn(*args, **kwargs)
            return __inner

        @limiter.limit("1/second")
        @app.route("/t1", methods=["GET", "POST"])
        def root():
            return "t1"

        @limiter.limit("1/second")
        @app.route("/t2", methods=["GET", "POST"])
        @something_else
        def t2():
            return "t2"

        @limiter.limit("2/second")
        @limiter.limit("1/second")
        @app.route("/t3", methods=["GET", "POST"])
        def t3():
            return "t3"

        with hiro.Timeline().freeze():
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/t1").status_code)
                self.assertEqual(429, cli.get("/t1").status_code)
                self.assertEqual(200, cli.get("/t2").status_code)
                self.assertEqual(429, cli.get("/t2").status_code)
                self.assertEqual(200, cli.get("/t3").status_code)
                self.assertEqual(429, cli.get("/t3").status_code)
734
735
736
class BlueprintTests(FlaskLimiterTestCase):
    """Rate limits applied to, or exempting, Flask blueprints."""

    def test_blueprint(self):
        """A route-level limit on a blueprint view takes the place of the default."""
        app, limiter = self.build_app(default_limits=["1/minute"])
        bp = Blueprint("main", __name__)

        @bp.route("/t1")
        def t1():
            return "test"

        @bp.route("/t2")
        @limiter.limit("10 per minute")
        def t2():
            return "test"

        app.register_blueprint(bp)

        with app.test_client() as cli:
            # /t1 only has the 1/minute default.
            self.assertEqual(cli.get("/t1").status_code, 200)
            self.assertEqual(cli.get("/t1").status_code, 429)
            # /t2 allows ten requests despite the 1/minute default.
            for _ in range(10):
                self.assertEqual(cli.get("/t2").status_code, 200)
            self.assertEqual(cli.get("/t2").status_code, 429)

    def test_register_blueprint(self):
        """Blueprint-wide static, exempt and callable limits.

        * bp1: static 1/second, counted separately for each of its views;
        * bp2: nothing of its own, so the 1/minute default applies;
        * bp3: exempt from limiting;
        * bp4: limit supplied by the callable ``dy_limit``.
        """
        app, limiter = self.build_app(default_limits=["1/minute"])
        bp_1 = Blueprint("bp1", __name__)
        bp_2 = Blueprint("bp2", __name__)
        bp_3 = Blueprint("bp3", __name__)
        bp_4 = Blueprint("bp4", __name__)

        @bp_1.route("/t1")
        def t1():
            return "test"

        @bp_1.route("/t2")
        def t2():
            return "test"

        @bp_2.route("/t3")
        def t3():
            return "test"

        @bp_3.route("/t4")
        def t4():
            return "test"

        # Previously a second ``def t4`` that shadowed bp_3's view; named
        # after its own route now.
        @bp_4.route("/t5")
        def t5():
            return "test"

        def dy_limit():
            return "1/second"

        app.register_blueprint(bp_1)
        app.register_blueprint(bp_2)
        app.register_blueprint(bp_3)
        app.register_blueprint(bp_4)

        # Limits and exemptions are attached after registration.
        limiter.limit("1/second")(bp_1)
        limiter.exempt(bp_3)
        limiter.limit(dy_limit)(bp_4)

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 200)
                self.assertEqual(cli.get("/t2").status_code, 429)
                timeline.forward(1)
                self.assertEqual(cli.get("/t2").status_code, 200)

                # bp2 falls back to 1/minute: blocked for the next 10 seconds.
                self.assertEqual(cli.get("/t3").status_code, 200)
                for _ in range(10):
                    timeline.forward(1)
                    self.assertEqual(cli.get("/t3").status_code, 429)

                # bp3 is exempt.
                for _ in range(10):
                    self.assertEqual(cli.get("/t4").status_code, 200)

                # bp4 uses the callable's "1/second".
                self.assertEqual(cli.get("/t5").status_code, 200)
                self.assertEqual(cli.get("/t5").status_code, 429)

    def test_invalid_decorated_static_limit_blueprint(self):
        """An unusable static blueprint limit is logged; the default still applies."""
        app = Flask(__name__)
        limiter = Limiter(
            app, default_limits=["1/second"], key_func=get_remote_address
        )
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)
        bp = Blueprint("bp1", __name__)

        @bp.route("/t1")
        def t1():
            return "42"

        # The test relies on "2/sec" being rejected: a configuration failure
        # is logged and the 1/second default is what gets enforced.
        limiter.limit("2/sec")(bp)
        app.register_blueprint(bp)

        with app.test_client() as cli:
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        # Each handle() call receives the LogRecord as its first positional arg.
        records = [c[0][0] for c in mock_handler.handle.call_args_list]
        self.assertIn("failed to configure", records[0].msg)
        self.assertIn("exceeded at endpoint", records[1].msg)

    def test_invalid_decorated_dynamic_limits_blueprint(self):
        """An unusable callable blueprint limit is logged on each request."""
        app = Flask(__name__)
        app.config.setdefault("X", "2 per sec")
        limiter = Limiter(
            app, default_limits=["1/second"], key_func=get_remote_address
        )
        mock_handler = mock.Mock()
        mock_handler.level = logging.INFO
        limiter.logger.addHandler(mock_handler)
        bp = Blueprint("bp1", __name__)

        @bp.route("/t1")
        def t1():
            return "42"

        limiter.limit(lambda: current_app.config.get("X"))(bp)
        app.register_blueprint(bp)

        with app.test_client() as cli:
            with hiro.Timeline().freeze():
                self.assertEqual(cli.get("/t1").status_code, 200)
                self.assertEqual(cli.get("/t1").status_code, 429)
        # Two load failures (one per request) plus the breach of the default.
        self.assertEqual(mock_handler.handle.call_count, 3)
        records = [c[0][0] for c in mock_handler.handle.call_args_list]
        self.assertIn("failed to load ratelimit", records[0].msg)
        self.assertIn("failed to load ratelimit", records[1].msg)
        self.assertIn("exceeded at endpoint", records[2].msg)
885
886
class ViewsTests(FlaskLimiterTestCase):
    """Limits declared through the ``decorators`` attribute of class-based views."""

    def test_pluggable_views(self):
        """``flask.views.View`` subclasses honour the limits in ``decorators``."""
        app, limiter = self.build_app(default_limits=["1/hour"])

        class Va(View):
            methods = ['GET', 'POST']
            decorators = [limiter.limit("2/second")]

            def dispatch_request(self):
                return request.method.lower()

        class Vb(View):
            methods = ['GET']
            decorators = [limiter.limit("1/second, 3/minute")]

            def dispatch_request(self):
                return request.method.lower()

        class Vc(View):
            methods = ['GET']

            def dispatch_request(self):
                return request.method.lower()

        app.add_url_rule("/a", view_func=Va.as_view("a"))
        app.add_url_rule("/b", view_func=Vb.as_view("b"))
        app.add_url_rule("/c", view_func=Vc.as_view("c"))
        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                # GET and POST on /a draw from the same 2/second limit.
                self.assertEqual(200, cli.get("/a").status_code)
                self.assertEqual(200, cli.get("/a").status_code)
                self.assertEqual(429, cli.post("/a").status_code)
                self._assert_b_and_c_limits(cli, timeline)

    def test_pluggable_method_views(self):
        """``MethodView`` subclasses honour the limits in ``decorators``."""
        app, limiter = self.build_app(default_limits=["1/hour"])

        class Va(MethodView):
            decorators = [limiter.limit("2/second")]

            def get(self):
                return request.method.lower()

            def post(self):
                return request.method.lower()

        class Vb(MethodView):
            decorators = [limiter.limit("1/second, 3/minute")]

            def get(self):
                return request.method.lower()

        class Vc(MethodView):
            def get(self):
                return request.method.lower()

        class Vd(MethodView):
            decorators = [limiter.limit("1/minute", methods=['get'])]

            def get(self):
                return request.method.lower()

            def post(self):
                return request.method.lower()

        app.add_url_rule("/a", view_func=Va.as_view("a"))
        app.add_url_rule("/b", view_func=Vb.as_view("b"))
        app.add_url_rule("/c", view_func=Vc.as_view("c"))
        app.add_url_rule("/d", view_func=Vd.as_view("d"))

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/a").status_code)
                self.assertEqual(200, cli.get("/a").status_code)
                self.assertEqual(429, cli.get("/a").status_code)
                self.assertEqual(429, cli.post("/a").status_code)
                self._assert_b_and_c_limits(cli, timeline)
                # Vd's limit is declared for GET only: the two POSTs below
                # still succeed after GET has been rejected.
                self.assertEqual(200, cli.get("/d").status_code)
                self.assertEqual(429, cli.get("/d").status_code)
                self.assertEqual(200, cli.post("/d").status_code)
                self.assertEqual(200, cli.post("/d").status_code)

    def test_flask_restful_resource(self):
        """flask-restful ``Resource`` classes honour the limits in ``decorators``."""
        app, limiter = self.build_app(default_limits=["1/hour"])
        api = RestfulApi(app)

        class Va(Resource):
            decorators = [limiter.limit("2/second")]

            def get(self):
                return request.method.lower()

            def post(self):
                return request.method.lower()

        class Vb(Resource):
            decorators = [limiter.limit("1/second, 3/minute")]

            def get(self):
                return request.method.lower()

        class Vc(Resource):
            def get(self):
                return request.method.lower()

        api.add_resource(Va, "/a")
        api.add_resource(Vb, "/b")
        api.add_resource(Vc, "/c")

        with hiro.Timeline().freeze() as timeline:
            with app.test_client() as cli:
                self.assertEqual(200, cli.get("/a").status_code)
                self.assertEqual(200, cli.get("/a").status_code)
                self.assertEqual(429, cli.get("/a").status_code)
                self.assertEqual(429, cli.post("/a").status_code)
                self._assert_b_and_c_limits(cli, timeline)

    def _assert_b_and_c_limits(self, cli, timeline):
        """Check the /b and /c request pattern shared by the tests above.

        ``/b`` carries "1/second, 3/minute": three GETs one second apart
        succeed and the fourth trips the per-minute part.  ``/c`` has no
        limit of its own, so the tests' 1/hour default allows one request.

        :param cli: an open Flask test client.
        :param timeline: the frozen ``hiro`` timeline driving the clock.
        """
        self.assertEqual(200, cli.get("/b").status_code)
        for _ in range(2):
            timeline.forward(1)
            self.assertEqual(200, cli.get("/b").status_code)
        timeline.forward(1)
        self.assertEqual(429, cli.get("/b").status_code)
        self.assertEqual(200, cli.get("/c").status_code)
        self.assertEqual(429, cli.get("/c").status_code)
1026
1027
1028
class FlaskExtTests(FlaskLimiterTestCase):
1029
    def test_reset(self):
        """``limiter.reset()`` lets an already-limited client through again."""
        app, limiter = self.build_app({C.GLOBAL_LIMITS: "1 per day"})

        @app.route("/")
        def null():
            return "Hello Reset"

        def body(client):
            return client.get("/").data.decode()

        with app.test_client() as client:
            client.get("/")  # spends the single daily request
            self.assertIn("1 per 1 day", body(client))
            limiter.reset()
            self.assertEqual("Hello Reset", body(client))
            self.assertIn("1 per 1 day", body(client))
1042
1043
    def test_combined_rate_limits(self):
        """Multi-part limit strings work both as defaults and on a route."""
        app, limiter = self.build_app(
            {C.GLOBAL_LIMITS: "1 per hour; 10 per day"}
        )

        @app.route("/t1")
        @limiter.limit("100 per hour;10/minute")
        def t1():
            return "t1"

        @app.route("/t2")
        def t2():
            return "t2"

        # /t2 only has the default, whose 1 per hour part blocks request two.
        expectations = ((200, "/t1"), (200, "/t2"), (429, "/t2"))
        with hiro.Timeline().freeze():
            with app.test_client() as client:
                for status, path in expectations:
                    self.assertEqual(status, client.get(path).status_code)
1062
1063 View Code Duplication
    def test_key_func(self):
        """A key function passed to ``limit`` picks the bucket requests count in.

        The route keys every request on the constant "test", so 100 requests
        sent with a forwarded-for header and one more without it all share a
        single bucket.
        """
        app, limiter = self.build_app()

        @app.route("/t1")
        @limiter.limit("100 per minute", lambda: "test")
        def t1():
            return "test"

        forwarded = {"X_FORWARDED_FOR": "127.0.0.2"}
        with hiro.Timeline().freeze():
            with app.test_client() as client:
                for _ in range(100):
                    response = client.get("/t1", headers=forwarded)
                    self.assertEqual(200, response.status_code)
                self.assertEqual(429, client.get("/t1").status_code)
1083
1084
    def test_logging(self):
        """One breached request produces exactly one record on the limiter logger."""
        app = Flask(__name__)
        limiter = Limiter(app, key_func=get_remote_address)
        handler = mock.Mock()
        handler.level = logging.INFO
        limiter.logger.addHandler(handler)

        @app.route("/t1")
        @limiter.limit("1/minute")
        def t1():
            return "test"

        with app.test_client() as client:
            for expected in (200, 429):
                self.assertEqual(expected, client.get("/t1").status_code)
        self.assertEqual(handler.handle.call_count, 1)
1100
1101
    def test_reuse_logging(self):
        """Handlers copied from ``app.logger`` receive the limiter's record once."""
        app = Flask(__name__)
        app_handler = mock.Mock()
        app_handler.level = logging.INFO
        app.logger.addHandler(app_handler)
        limiter = Limiter(app, key_func=get_remote_address)
        for existing in app.logger.handlers:
            limiter.logger.addHandler(existing)

        @app.route("/t1")
        @limiter.limit("1/minute")
        def t1():
            return "42"

        with app.test_client() as client:
            for _ in range(2):
                client.get("/t1")

        self.assertEqual(app_handler.handle.call_count, 1)
1120
1121
    def test_disabled_flag(self):
        """With ``C.ENABLED`` set to False neither default nor route limits apply."""
        app, limiter = self.build_app(
            config={C.ENABLED: False}, default_limits=["1/minute"]
        )

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.limit("10 per minute")
        def t2():
            return "test"

        with app.test_client() as client:
            # One request past each configured limit is still served.
            for _ in range(2):
                self.assertEqual(client.get("/t1").status_code, 200)
            for _ in range(11):
                self.assertEqual(client.get("/t2").status_code, 200)
1141
1142
1143
    def test_multiple_apps(self):
        """One limiter initialised on two apps enforces each app's own limits."""
        app1 = Flask(__name__)
        app2 = Flask(__name__)

        limiter = Limiter(
            default_limits=["1/second"], key_func=get_remote_address
        )
        for application in (app1, app2):
            limiter.init_app(application)

        # NOTE(review): the view names differ across apps (ping / ping_2,
        # slow_ping / slow_ping_2); if the limiter keys decorated limits by
        # function name this is required - keep them distinct.
        @app1.route("/ping")
        def ping():
            return "PONG"

        @app1.route("/slowping")
        @limiter.limit("1/minute")
        def slow_ping():
            return "PONG"

        @app2.route("/ping")
        @limiter.limit("2/second")
        def ping_2():
            return "PONG"

        @app2.route("/slowping")
        @limiter.limit("2/minute")
        def slow_ping_2():
            return "PONG"

        def status(client, path):
            return client.get(path).status_code

        with hiro.Timeline().freeze() as clock:
            # app1: /ping has the 1/second default, /slowping 1/minute.
            with app1.test_client() as client:
                self.assertEqual(status(client, "/ping"), 200)
                self.assertEqual(status(client, "/ping"), 429)
                clock.forward(1)
                self.assertEqual(status(client, "/ping"), 200)
                self.assertEqual(status(client, "/slowping"), 200)
                clock.forward(59)
                self.assertEqual(status(client, "/slowping"), 429)
                clock.forward(1)
                self.assertEqual(status(client, "/slowping"), 200)
            # app2: /ping allows 2/second, /slowping 2/minute.
            with app2.test_client() as client:
                self.assertEqual(status(client, "/ping"), 200)
                self.assertEqual(status(client, "/ping"), 200)
                self.assertEqual(status(client, "/ping"), 429)
                clock.forward(1)
                self.assertEqual(status(client, "/ping"), 200)
                self.assertEqual(status(client, "/slowping"), 200)
                clock.forward(59)
                self.assertEqual(status(client, "/slowping"), 200)
                self.assertEqual(status(client, "/slowping"), 429)
                clock.forward(1)
                self.assertEqual(status(client, "/slowping"), 200)
1195
1196
    def test_headers_no_breach(self):
        """Rate-limit headers are sent on successful responses.

        /t1 reports the 10/minute default; /t2 reports its 2/second limit.
        """
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address,
        )

        @app.route("/t1")
        def t1():
            return "test"

        @app.route("/t2")
        @limiter.limit("2/second; 5 per minute; 10/hour")
        def t2():
            return "test"

        with hiro.Timeline().freeze():
            with app.test_client() as client:
                headers = client.get("/t1").headers
                self.assertEqual(headers.get('X-RateLimit-Limit'), '10')
                self.assertEqual(headers.get('X-RateLimit-Remaining'), '9')
                self.assertEqual(
                    headers.get('X-RateLimit-Reset'),
                    str(int(time.time() + 61)),
                )
                self.assertEqual(headers.get('Retry-After'), str(60))

                headers = client.get("/t2").headers
                self.assertEqual(headers.get('X-RateLimit-Limit'), '2')
                self.assertEqual(headers.get('X-RateLimit-Remaining'), '1')
                self.assertEqual(
                    headers.get('X-RateLimit-Reset'),
                    str(int(time.time() + 2)),
                )
                self.assertEqual(headers.get('Retry-After'), str(1))
1237
1238
    def test_headers_breach(self):
        """After eleven requests a second apart the headers report 10/minute spent."""
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address,
        )

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze() as clock:
            with app.test_client() as client:
                for _ in range(11):
                    response = client.get("/t1")
                    clock.forward(1)

                headers = response.headers
                self.assertEqual(headers.get('X-RateLimit-Limit'), '10')
                self.assertEqual(headers.get('X-RateLimit-Remaining'), '0')
                self.assertEqual(
                    headers.get('X-RateLimit-Reset'),
                    str(int(time.time() + 50)),
                )
                self.assertEqual(headers.get('Retry-After'), str(50))
1267
1268
    def test_retry_after(self):
        """Waiting the advertised Retry-After lets the next request through."""
        app = Flask(__name__)
        _ = Limiter(
            app,
            default_limits=["1/minute"],
            headers_enabled=True,
            key_func=get_remote_address,
        )

        @app.route("/t1")
        def t():
            return "test"

        with hiro.Timeline().freeze() as clock:
            with app.test_client() as client:
                wait = int(client.get("/t1").headers.get('Retry-After'))
                self.assertGreater(wait, 0)
                clock.forward(wait)
                self.assertEqual(client.get("/t1").status_code, 200)
1289
1290
    def test_custom_headers_from_setter(self):
        """Header names can be remapped by writing to ``_header_mapping``.

        ``retry_after='http-date'`` renders Retry-After as an HTTP date; the
        clock is frozen at timestamp 0 so the expected date is fixed.
        """
        app = Flask(__name__)
        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address,
            retry_after='http-date',
        )
        renamed = (
            (HEADERS.RESET, 'X-Reset'),
            (HEADERS.LIMIT, 'X-Limit'),
            (HEADERS.REMAINING, 'X-Remaining'),
        )
        for header, name in renamed:
            limiter._header_mapping[header] = name

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze(0) as clock:
            with app.test_client() as client:
                for _ in range(11):
                    response = client.get("/t1")
                    clock.forward(1)

                headers = response.headers
                self.assertEqual(headers.get('X-Limit'), '10')
                self.assertEqual(headers.get('X-Remaining'), '0')
                self.assertEqual(
                    headers.get('X-Reset'), str(int(time.time() + 50))
                )
                self.assertEqual(
                    headers.get('Retry-After'),
                    'Thu, 01 Jan 1970 00:01:01 GMT',
                )
1323
1324
    def test_custom_headers_from_config(self):
        """Header names can be remapped through the app config."""
        app = Flask(__name__)
        for key, name in (
            (C.HEADER_LIMIT, "X-Limit"),
            (C.HEADER_REMAINING, "X-Remaining"),
            (C.HEADER_RESET, "X-Reset"),
        ):
            app.config.setdefault(key, name)
        limiter = Limiter(
            app,
            default_limits=["10/minute"],
            headers_enabled=True,
            key_func=get_remote_address,
        )

        @app.route("/t1")
        @limiter.limit("2/second; 10 per minute; 20/hour")
        def t():
            return "test"

        with hiro.Timeline().freeze() as clock:
            with app.test_client() as client:
                for _ in range(11):
                    response = client.get("/t1")
                    clock.forward(1)

                headers = response.headers
                self.assertEqual(headers.get('X-Limit'), '10')
                self.assertEqual(headers.get('X-Remaining'), '0')
                self.assertEqual(
                    headers.get('X-Reset'), str(int(time.time() + 50))
                )
1352
1353 View Code Duplication
    def test_application_shared_limit(self):
        """An application limit is a single budget shared by all routes."""
        app, limiter = self.build_app(application_limits=["2/minute"])

        @app.route("/t1")
        def t1():
            return "route1"

        @app.route("/t2")
        def t2():
            return "route2"

        expectations = ((200, "/t1"), (200, "/t2"), (429, "/t1"))
        with hiro.Timeline().freeze():
            with app.test_client() as client:
                for status, path in expectations:
                    self.assertEqual(status, client.get(path).status_code)
1369
1370
    def test_callable_default_limit(self):
        """A callable default limit is enforced separately on each route."""
        app, limiter = self.build_app(default_limits=[lambda: "1/minute"])

        @app.route("/t1")
        def t1():
            return "t1"

        @app.route("/t2")
        def t2():
            return "t2"

        with hiro.Timeline().freeze():
            with app.test_client() as client:
                for expected in (200, 429):
                    self.assertEqual(client.get("/t1").status_code, expected)
                    self.assertEqual(client.get("/t2").status_code, expected)
1387
1388 View Code Duplication
    def test_callable_application_limit(self):
        """A callable application limit is still shared across routes."""
        app, limiter = self.build_app(application_limits=[lambda: "1/minute"])

        @app.route("/t1")
        def t1():
            return "t1"

        @app.route("/t2")
        def t2():
            return "t2"

        with hiro.Timeline().freeze():
            with app.test_client() as client:
                first = client.get("/t1").status_code
                self.assertEqual(first, 200)
                second = client.get("/t2").status_code
                self.assertEqual(second, 429)
1404
1405
    def test_no_auto_check(self):
        """With ``auto_check=False`` limits apply only once ``check()`` is called."""
        app, limiter = self.build_app(auto_check=False)

        @limiter.limit("1/second", per_method=True)
        @app.route("/", methods=["GET", "POST"])
        def root():
            return "root"

        def two_gets():
            # Two GETs inside a fresh frozen timeline.
            with hiro.Timeline().freeze():
                with app.test_client() as client:
                    return [client.get("/").status_code for _ in range(2)]

        self.assertEqual([200, 200], two_gets())

        # Hook the check in manually through before_request.
        @app.before_request
        def _():
            limiter.check()

        self.assertEqual([200, 429], two_gets())
1427
1428
1429
    def test_custom_key_prefix(self):
        """Apps sharing one Redis storage each get their own 1/day request.

        The prefixes are "moo" (constructor argument), "cow" (``C.KEY_PREFIX``
        config) and none for the third app.
        """
        app1, limiter1 = self.build_app(
            key_prefix="moo", storage_uri="redis://localhost:6379"
        )
        app2, limiter2 = self.build_app(
            {C.KEY_PREFIX: "cow"}, storage_uri="redis://localhost:6379"
        )
        app3, limiter3 = self.build_app(storage_uri="redis://localhost:6379")

        # NOTE(review): every app reuses the view name ``t1``. This appears
        # intended, so that the apps presumably differ only by key prefix.
        # Confirm before renaming.
        @app1.route("/test")
        @limiter1.limit("1/day")
        def t1():
            return "app1 test"

        @app2.route("/test")
        @limiter2.limit("1/day")
        def t1():
            return "app1 test"

        @app3.route("/test")
        @limiter3.limit("1/day")
        def t1():
            return "app1 test"

        for app in (app1, app2, app3):
            with app.test_client() as client:
                self.assertEqual(200, client.get("/test").status_code)
                self.assertEqual(429, client.get("/test").status_code)
1469