Limiter.exempt()   A
last analyzed

Complexity

Conditions 3

Size

Total Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
c 0
b 0
f 0
dl 0
loc 15
rs 9.4285
1
"""
2
the flask extension
3
"""
4
import datetime
5
import itertools
6
import logging
7
import sys
8
import time
9
import warnings
10
from functools import wraps
11
12
import six
13
from flask import request, current_app, g, Blueprint
14
from limits.errors import ConfigurationError
15
from limits.storage import storage_from_string, MemoryStorage
16
from limits.strategies import STRATEGIES
17
from werkzeug.http import http_date, parse_date
18
19
from flask_limiter.wrappers import Limit, LimitGroup
20
from .errors import RateLimitExceeded
21
from .util import get_ipaddr
22
23
24
class C:
    """
    Names of the ``app.config`` keys consulted by :meth:`Limiter.init_app`.
    """

    # master switches
    ENABLED = "RATELIMIT_ENABLED"
    SWALLOW_ERRORS = "RATELIMIT_SWALLOW_ERRORS"
    KEY_PREFIX = "RATELIMIT_KEY_PREFIX"

    # storage backend and limiting algorithm
    STORAGE_URL = "RATELIMIT_STORAGE_URL"
    STORAGE_OPTIONS = "RATELIMIT_STORAGE_OPTIONS"
    STRATEGY = "RATELIMIT_STRATEGY"
    IN_MEMORY_FALLBACK = "RATELIMIT_IN_MEMORY_FALLBACK"

    # limit definitions (GLOBAL_LIMITS is the deprecated name of DEFAULT_LIMITS)
    GLOBAL_LIMITS = "RATELIMIT_GLOBAL"
    DEFAULT_LIMITS = "RATELIMIT_DEFAULT"
    APPLICATION_LIMITS = "RATELIMIT_APPLICATION"

    # response headers
    HEADERS_ENABLED = "RATELIMIT_HEADERS_ENABLED"
    HEADER_LIMIT = "RATELIMIT_HEADER_LIMIT"
    HEADER_REMAINING = "RATELIMIT_HEADER_REMAINING"
    HEADER_RESET = "RATELIMIT_HEADER_RESET"
    HEADER_RETRY_AFTER = "RATELIMIT_HEADER_RETRY_AFTER"
    HEADER_RETRY_AFTER_VALUE = "RATELIMIT_HEADER_RETRY_AFTER_VALUE"
41
42
43
class HEADERS:
    """
    Symbolic keys for ``Limiter._header_mapping``, which maps each rate-limit
    response header to the header name actually written on responses.
    """

    RESET, REMAINING, LIMIT, RETRY_AFTER = 1, 2, 3, 4
48
49
50
# Cap on the exponent of the back-off used by Limiter.__should_check_backend:
# while the primary storage is marked dead, it is re-probed only after more than
# 2 ** n seconds have passed since the previous probe, and n wraps back to 0
# once it exceeds this value.
MAX_BACKEND_CHECKS = 5
51
52
53
class Limiter(object):
    """
    Rate limiting extension for Flask applications.

    :param app: :class:`flask.Flask` instance to initialize the extension
     with.
    :param function key_func: a callable that returns the domain to rate limit
     by. When omitted :func:`get_ipaddr` is used and a :class:`UserWarning`
     is emitted.
    :param list global_limits: deprecated alias of ``default_limits``; passing
     it emits a :class:`UserWarning`. Its entries are merged (de-duplicated,
     order preserved) with ``default_limits``.
    :param list default_limits: a variable list of strings or callables
     returning strings denoting global limits to apply to all routes.
     :ref:`ratelimit-string` for more details.
    :param list application_limits: a variable list of strings or callables
     returning strings for limits that are applied to the entire application
     (i.e a shared limit for all routes)
    :param bool headers_enabled: whether ``X-RateLimit`` response headers are
     written.
    :param str strategy: the strategy to use. refer to :ref:`ratelimit-strategy`
    :param str storage_uri: the storage location. refer to :ref:`ratelimit-conf`
    :param dict storage_options: kwargs to pass to the storage implementation
     upon instantiation. The mapping is copied, so the caller's dict is never
     modified.
    :param bool auto_check: whether to automatically check the rate limit in
     the before_request chain of the application. default ``True``
    :param bool swallow_errors: whether to swallow (and log) errors raised
     while checking limits, e.g. storage failures. :class:`RateLimitExceeded`
     is always propagated. default ``False``
    :param list in_memory_fallback: a variable list of strings or callables
     returning strings denoting fallback limits to apply when the storage is
     down.
    :param str retry_after: ``'http-date'`` to write the retry-after header as
     an HTTP date; any other value writes it as a number of seconds.
    :param str key_prefix: prefix prepended to rate limiter keys.
    :param bool enabled: when false no limits are checked and no headers are
     written. ``RATELIMIT_ENABLED`` in the app config overrides it.
    """

    def __init__(
        self,
        app=None,
        key_func=None,
        global_limits=None,
        default_limits=None,
        application_limits=None,
        headers_enabled=False,
        strategy=None,
        storage_uri=None,
        storage_options=None,
        auto_check=True,
        swallow_errors=False,
        in_memory_fallback=None,
        retry_after=None,
        key_prefix="",
        enabled=True
    ):
        # Container parameters default to None instead of shared mutable
        # literals; normalise them to fresh containers here.
        global_limits = global_limits or []
        default_limits = default_limits or []
        application_limits = application_limits or []
        in_memory_fallback = in_memory_fallback or []

        self.app = app
        self.logger = logging.getLogger("flask-limiter")

        self.enabled = enabled
        self._default_limits = []
        self._application_limits = []
        self._in_memory_fallback = []
        self._exempt_routes = set()
        self._request_filters = []
        self._headers_enabled = headers_enabled
        self._header_mapping = {}
        self._retry_after = retry_after
        self._strategy = strategy
        self._storage_uri = storage_uri
        # Copied: init_app() merges app config into this dict, which must not
        # leak into the caller's mapping or into other Limiter instances.
        self._storage_options = dict(storage_options or {})
        self._auto_check = auto_check
        self._swallow_errors = swallow_errors
        if not key_func:
            warnings.warn(
                "Use of the default `get_ipaddr` function is discouraged."
                " Please refer to https://flask-limiter.readthedocs.org/#rate-limit-domain"
                " for the recommended configuration", UserWarning
            )
        if global_limits:
            self.raise_global_limits_warning()

        self._key_func = key_func or get_ipaddr
        self._key_prefix = key_prefix

        # De-duplicate while keeping declaration order, so limits are always
        # evaluated in the same order (a set's order varies between processes
        # because of string hash randomisation).
        seen = set()
        for limit in global_limits + default_limits:
            if limit in seen:
                continue
            seen.add(limit)
            self._default_limits.append(
                LimitGroup(
                    limit, self._key_func, None, False, None, None, None
                )
            )
        for limit in application_limits:
            self._application_limits.append(
                LimitGroup(
                    limit, self._key_func, "global", False, None, None,
                    None
                )
            )
        for limit in in_memory_fallback:
            self._in_memory_fallback.append(
                LimitGroup(
                    limit, self._key_func, None, False, None, None, None
                )
            )
        self._route_limits = {}
        self._dynamic_route_limits = {}
        self._blueprint_limits = {}
        self._blueprint_dynamic_limits = {}
        self._blueprint_exempt = set()
        self._storage = self._limiter = None
        self._storage_dead = False
        self._fallback_storage = None
        self._fallback_limiter = None
        self.__check_backend_count = 0
        self.__last_check_backend = time.time()
        self.__marked_for_limiting = {}

        # Attach a do-nothing handler (presumably to silence "no handlers
        # could be found" warnings). The named logger is a process-wide
        # singleton, so add it only once rather than once per instance.
        if not any(
            isinstance(handler, logging.NullHandler)
            for handler in self.logger.handlers
        ):
            self.logger.addHandler(logging.NullHandler())
        if app:
            self.init_app(app)

    def init_app(self, app):
        """
        Bind the extension to ``app``.

        Values passed to the constructor generally take precedence over
        ``app.config``. The exceptions are ``RATELIMIT_ENABLED`` and
        ``RATELIMIT_SWALLOW_ERRORS`` (config wins), ``RATELIMIT_HEADERS_ENABLED``
        (OR-ed with the constructor flag) and ``RATELIMIT_STORAGE_OPTIONS``
        (merged over the constructor options). Request hooks are registered
        only the first time a limiter is attached to ``app``.

        :param app: :class:`flask.Flask` instance to rate limit.
        :raises ConfigurationError: if the configured strategy is unknown.
        """
        self.enabled = app.config.setdefault(C.ENABLED, self.enabled)
        self._swallow_errors = app.config.setdefault(
            C.SWALLOW_ERRORS, self._swallow_errors
        )
        self._headers_enabled = (
            self._headers_enabled
            or app.config.setdefault(C.HEADERS_ENABLED, False)
        )
        self._storage_options.update(app.config.get(C.STORAGE_OPTIONS, {}))
        self._storage = storage_from_string(
            self._storage_uri
            or app.config.setdefault(C.STORAGE_URL, 'memory://'),
            **self._storage_options
        )
        strategy = (
            self._strategy
            or app.config.setdefault(C.STRATEGY, 'fixed-window')
        )
        if strategy not in STRATEGIES:
            raise ConfigurationError(
                "Invalid rate limiting strategy %s" % strategy
            )
        self._limiter = STRATEGIES[strategy](self._storage)
        self._header_mapping.update(
            {
                HEADERS.RESET:
                self._header_mapping.get(HEADERS.RESET, None)
                or app.config.setdefault(C.HEADER_RESET, "X-RateLimit-Reset"),
                HEADERS.REMAINING:
                self._header_mapping.get(HEADERS.REMAINING, None)
                or app.config.setdefault(
                    C.HEADER_REMAINING, "X-RateLimit-Remaining"
                ),
                HEADERS.LIMIT:
                self._header_mapping.get(HEADERS.LIMIT, None)
                or app.config.setdefault(C.HEADER_LIMIT, "X-RateLimit-Limit"),
                HEADERS.RETRY_AFTER:
                self._header_mapping.get(HEADERS.RETRY_AFTER, None)
                or app.config.setdefault(C.HEADER_RETRY_AFTER, "Retry-After"),
            }
        )
        self._retry_after = (
            self._retry_after or app.config.get(C.HEADER_RETRY_AFTER_VALUE)
        )
        self._key_prefix = self._key_prefix or app.config.get(C.KEY_PREFIX)
        app_limits = app.config.get(C.APPLICATION_LIMITS, None)
        if not self._application_limits and app_limits:
            self._application_limits = [
                LimitGroup(
                    app_limits, self._key_func, "global", False, None, None,
                    None
                )
            ]

        if app.config.get(C.GLOBAL_LIMITS, None):
            self.raise_global_limits_warning()
        conf_limits = app.config.get(
            C.GLOBAL_LIMITS, app.config.get(C.DEFAULT_LIMITS, None)
        )
        if not self._default_limits and conf_limits:
            self._default_limits = [
                LimitGroup(
                    conf_limits, self._key_func, None, False, None, None, None
                )
            ]
        fallback_limits = app.config.get(C.IN_MEMORY_FALLBACK, None)
        if not self._in_memory_fallback and fallback_limits:
            self._in_memory_fallback = [
                LimitGroup(
                    fallback_limits, self._key_func, None, False, None, None,
                    None
                )
            ]

        if self._in_memory_fallback:
            self._fallback_storage = MemoryStorage()
            self._fallback_limiter = STRATEGIES[strategy](
                self._fallback_storage
            )

        # purely for backward compatibility as stated in flask documentation
        if not hasattr(app, 'extensions'):
            app.extensions = {}  # pragma: no cover

        if not app.extensions.get('limiter'):
            if self._auto_check:
                app.before_request(self.__check_request_limit)
            app.after_request(self.__inject_headers)

        app.extensions['limiter'] = self

    def __should_check_backend(self):
        """
        Decide whether a dead primary storage should be probed again.

        Probes back off exponentially: another probe is allowed only after
        more than ``2 ** n`` seconds since the previous one, where ``n``
        counts probes and wraps to 0 once it exceeds ``MAX_BACKEND_CHECKS``.
        """
        if self.__check_backend_count > MAX_BACKEND_CHECKS:
            self.__check_backend_count = 0
        if time.time() - self.__last_check_backend > pow(
            2, self.__check_backend_count
        ):
            self.__last_check_backend = time.time()
            self.__check_backend_count += 1
            return True
        return False

    def check(self):
        """
        check the limits for the current request

        Runs the same evaluation as a decorated view would. Application-wide
        limits are only applied by the ``before_request`` hook.

        :raises: RateLimitExceeded
        """
        self.__check_request_limit(False)

    def reset(self):
        """
        resets the storage if it supports being reset

        Only the primary storage is reset; a storage without reset support
        is logged and otherwise ignored.
        """
        try:
            self._storage.reset()
            self.logger.info("Storage has been reset and all limits cleared")
        except NotImplementedError:
            self.logger.warning(
                "This storage type does not support being reset"
            )

    @property
    def limiter(self):
        """
        The strategy currently in use: the in-memory fallback while the
        primary storage is marked dead (and a fallback is configured),
        otherwise the primary strategy.
        """
        if self._storage_dead and self._in_memory_fallback:
            return self._fallback_limiter
        else:
            return self._limiter

    def __parse_retry_after(self, value):
        """
        Convert an existing retry-after header value to an epoch timestamp.

        :param str value: either an HTTP date or a number of seconds.
        :return: seconds since the epoch, or ``None`` if ``value`` can be
         parsed as neither.
        """
        import calendar  # only needed on this rarely taken path

        parsed = parse_date(value)
        if isinstance(parsed, datetime.datetime):
            # werkzeug returns the date in UTC (naive or aware depending on
            # version); timegm interprets the tuple as UTC, whereas
            # time.mktime would wrongly treat it as local time.
            return calendar.timegm(parsed.utctimetuple())
        try:
            return time.time() + int(value)
        except (TypeError, ValueError):
            self.logger.warning(
                "Ignoring unparseable retry-after header value: %r", value
            )
            return None

    def __inject_headers(self, response):
        """
        ``after_request`` hook: write rate-limit headers for the limit that
        was recorded in ``g.view_rate_limit`` while handling the request.

        :param response: the outgoing response; its headers are modified.
        :return: the same response object.
        """
        current_limit = getattr(g, 'view_rate_limit', None)
        if not (self.enabled and self._headers_enabled and current_limit):
            return response

        window_stats = self.limiter.get_window_stats(*current_limit)
        reset_in = 1 + window_stats[0]
        response.headers.add(
            self._header_mapping[HEADERS.LIMIT],
            str(current_limit[0].amount)
        )
        response.headers.add(
            self._header_mapping[HEADERS.REMAINING], window_stats[1]
        )
        response.headers.add(self._header_mapping[HEADERS.RESET], reset_in)

        # The view may already have set a retry-after; honour whichever of
        # the two lies further in the future. Read it under the configured
        # name, which is the header overwritten below.
        retry_after_header = self._header_mapping[HEADERS.RETRY_AFTER]
        existing_retry_after = response.headers.get(retry_after_header)
        if existing_retry_after is not None:
            retry_after = self.__parse_retry_after(existing_retry_after)
            if retry_after is not None:
                reset_in = max(retry_after, reset_in)

        # set the header instead of using add
        if self._retry_after == 'http-date':
            retry_after_value = http_date(reset_in)
        else:
            retry_after_value = int(reset_in - time.time())
        response.headers.set(retry_after_header, retry_after_value)
        return response

    def __evaluate_limits(self, endpoint, limits):
        """
        Hit every applicable limit in ``limits`` for the current request.

        Limits that are exempt, restricted to other HTTP methods, or whose
        key/scope is empty are skipped individually; the remaining ones are
        still evaluated. Evaluation stops at the first exceeded limit.

        :param str endpoint: default scope for limits without their own scope.
        :param limits: iterable of limit wrappers.
        :raises RateLimitExceeded: when one of the limits is exceeded.
        """
        failed_limit = None
        limit_for_header = None
        for lim in limits:
            limit_scope = lim.scope or endpoint
            if lim.is_exempt:
                continue
            if lim.methods is not None and request.method.lower(
            ) not in lim.methods:
                continue
            if lim.per_method:
                limit_scope += ":%s" % request.method
            limit_key = lim.key_func()

            args = [limit_key, limit_scope]
            if not all(args):
                self.logger.error(
                    "Skipping limit: %s. Empty value found in parameters.",
                    lim.limit
                )
                continue
            if self._key_prefix:
                args = [self._key_prefix] + args
            # headers report the most restrictive limit seen so far
            if not limit_for_header or lim.limit < limit_for_header[0]:
                limit_for_header = [lim.limit] + args
            if not self.limiter.hit(lim.limit, *args):
                self.logger.warning(
                    "ratelimit %s (%s) exceeded at endpoint: %s",
                    lim.limit, limit_key, limit_scope
                )
                failed_limit = lim
                limit_for_header = [lim.limit] + args
                break

        # Keep any header data recorded earlier in this request (e.g. by the
        # before_request pass) when nothing was evaluated here.
        if limit_for_header is not None:
            g.view_rate_limit = limit_for_header

        if failed_limit:
            message = failed_limit.error_message
            if message:
                exc_description = message() if callable(message) else message
            else:
                exc_description = six.text_type(failed_limit.limit)
            raise RateLimitExceeded(exc_description)

    def __check_request_limit(self, in_middleware=True):
        """
        Collect and evaluate every limit that applies to the current request.

        :param bool in_middleware: ``True`` when called from the
         ``before_request`` hook (application limits apply; route limits are
         left to the view decorator unless it was applied outside
         ``@app.route``), ``False`` when called from a decorated view or
         :meth:`check`.
        :raises RateLimitExceeded: when a limit is exceeded.

        On a storage error the in-memory fallback (if configured) is activated
        and the check is retried; otherwise the error is logged when
        ``swallow_errors`` is set, or re-raised.
        """
        endpoint = request.endpoint or ""
        view_func = current_app.view_functions.get(endpoint, None)
        name = (
            "%s.%s" % (view_func.__module__, view_func.__name__)
            if view_func else ""
        )
        if (not request.endpoint
            or not self.enabled
            or view_func == current_app.send_static_file
            or name in self._exempt_routes
            or request.blueprint in self._blueprint_exempt
            or any(fn() for fn in self._request_filters)
            or g.get("_rate_limiting_complete")
        ):
            return
        limits, dynamic_limits = [], []

        # this is to ensure backward compatibility with behavior that
        # existed accidentally, i.e::
        #
        # @limiter.limit(...)
        # @app.route('...')
        # def func(...):
        #
        # The above setup would work in pre 1.0 versions because the decorator
        # was not acting immediately and instead simply registering the rate
        # limiting. The correct way to use the decorator is to wrap
        # the limiter with the route, i.e::
        #
        # @app.route(...)
        # @limiter.limit(...)
        # def func(...):

        implicit_decorator = view_func in self.__marked_for_limiting.get(
            name, []
        )

        if not in_middleware or implicit_decorator:
            # copy so later extends never touch the registered list
            limits = list(self._route_limits.get(name, []))
            for lim in self._dynamic_route_limits.get(name, []):
                try:
                    dynamic_limits.extend(list(lim))
                except ValueError as e:
                    self.logger.error(
                        "failed to load ratelimit for view function %s (%s)",
                        name, e
                    )
        if request.blueprint:
            if (request.blueprint in self._blueprint_dynamic_limits
                and not dynamic_limits
            ):
                for limit_group in self._blueprint_dynamic_limits[
                    request.blueprint
                ]:
                    try:
                        dynamic_limits.extend(
                            [
                                Limit(
                                    limit.limit, limit.key_func, limit.scope,
                                    limit.per_method, limit.methods,
                                    limit.error_message, limit.exempt_when
                                ) for limit in limit_group
                            ]
                        )
                    except ValueError as e:
                        self.logger.error(
                            "failed to load ratelimit for blueprint %s (%s)",
                            request.blueprint, e
                        )
            if request.blueprint in self._blueprint_limits and not limits:
                limits.extend(self._blueprint_limits[request.blueprint])

        try:
            all_limits = []
            if self._storage_dead and self._fallback_limiter:
                if not (in_middleware and name in self.__marked_for_limiting):
                    if self.__should_check_backend() and self._storage.check():
                        self.logger.info("Rate limit storage recovered")
                        self._storage_dead = False
                        self.__check_backend_count = 0
                    else:
                        all_limits = list(
                            itertools.chain(*self._in_memory_fallback)
                        )
            if not all_limits:
                route_limits = limits + dynamic_limits
                all_limits = (
                    list(itertools.chain(*self._application_limits))
                    if in_middleware else []
                )
                all_limits += route_limits
                if (
                    (not route_limits
                     and not (in_middleware
                              and name in self.__marked_for_limiting))
                    or implicit_decorator
                ):
                    all_limits += list(itertools.chain(*self._default_limits))
            self.__evaluate_limits(endpoint, all_limits)
        except RateLimitExceeded:
            raise
        except Exception:  # noqa: storage/backend failures of any kind
            if self._in_memory_fallback and not self._storage_dead:
                self.logger.warning(
                    "Rate limit storage unreachable - falling back to"
                    " in-memory storage"
                )
                self._storage_dead = True
                self.__check_request_limit(in_middleware)
            elif self._swallow_errors:
                self.logger.exception(
                    "Failed to rate limit. Swallowing error"
                )
            else:
                raise

    def __limit_decorator(
        self,
        limit_value,
        key_func=None,
        shared=False,
        scope=None,
        per_method=False,
        methods=None,
        error_message=None,
        exempt_when=None,
    ):
        """
        Build the decorator shared by :meth:`limit` and :meth:`shared_limit`.

        The returned decorator registers the limits for a view function (and
        wraps it so the limits are checked when it runs) or for a Blueprint
        (returned unchanged).
        """
        _scope = scope if shared else None

        def _inner(obj):
            func = key_func or self._key_func
            is_route = not isinstance(obj, Blueprint)
            name = "%s.%s" % (
                obj.__module__, obj.__name__
            ) if is_route else obj.name
            dynamic_limit, static_limits = None, []
            if callable(limit_value):
                # resolved per request, see __check_request_limit
                dynamic_limit = LimitGroup(
                    limit_value, func, _scope, per_method, methods,
                    error_message, exempt_when
                )
            else:
                try:
                    static_limits = list(
                        LimitGroup(
                            limit_value, func, _scope, per_method, methods,
                            error_message, exempt_when
                        )
                    )
                except ValueError as e:
                    self.logger.error(
                        "failed to configure %s %s (%s)", "view function"
                        if is_route else "blueprint", name, e
                    )
            if not is_route:
                if dynamic_limit:
                    self._blueprint_dynamic_limits.setdefault(
                        name, []
                    ).append(dynamic_limit)
                else:
                    self._blueprint_limits.setdefault(
                        name, []
                    ).extend(static_limits)
                return obj

            self.__marked_for_limiting.setdefault(name, []).append(obj)
            if dynamic_limit:
                self._dynamic_route_limits.setdefault(
                    name, []
                ).append(dynamic_limit)
            else:
                self._route_limits.setdefault(
                    name, []
                ).extend(static_limits)

            @wraps(obj)
            def __inner(*a, **k):
                if self._auto_check and not g.get("_rate_limiting_complete"):
                    self.__check_request_limit(False)
                    g._rate_limiting_complete = True
                return obj(*a, **k)
            return __inner
        return _inner

    def limit(
        self,
        limit_value,
        key_func=None,
        per_method=False,
        methods=None,
        error_message=None,
        exempt_when=None,
    ):
        """
        decorator to be used for rate limiting individual routes or blueprints.

        :param limit_value: rate limit string or a callable that returns a string.
         :ref:`ratelimit-string` for more details.
        :param function key_func: function/lambda to extract the unique identifier for
         the rate limit. defaults to remote address of the request.
        :param bool per_method: whether the limit is sub categorized into the http
         method of the request.
        :param list methods: if specified, only the methods in this list will be rate
         limited (default: None).
        :param error_message: string (or callable that returns one) to override the
         error message used in the response.
        :param exempt_when: presumably a zero-argument callable; when it
         returns a truthy value this limit is skipped for the request (checked
         through the limit wrapper's ``is_exempt`` -- verify in
         ``flask_limiter.wrappers``).
        :return: a decorator accepting a view function (returns a wrapped view)
         or a :class:`flask.Blueprint` (returns the blueprint).
        """
        return self.__limit_decorator(
            limit_value,
            key_func,
            per_method=per_method,
            methods=methods,
            error_message=error_message,
            exempt_when=exempt_when,
        )

    def shared_limit(
        self,
        limit_value,
        scope,
        key_func=None,
        error_message=None,
        exempt_when=None,
    ):
        """
        decorator to be applied to multiple routes sharing the same rate limit.

        :param limit_value: rate limit string or a callable that returns a string.
         :ref:`ratelimit-string` for more details.
        :param scope: a string or callable that returns a string
         for defining the rate limiting scope.
        :param function key_func: function/lambda to extract the unique identifier for
         the rate limit. defaults to remote address of the request.
        :param error_message: string (or callable that returns one) to override the
         error message used in the response.
        :param exempt_when: presumably a zero-argument callable; when it
         returns a truthy value this limit is skipped for the request.
        :return: a decorator accepting a view function or a blueprint.
        """
        return self.__limit_decorator(
            limit_value,
            key_func,
            True,
            scope,
            error_message=error_message,
            exempt_when=exempt_when,
        )

    def exempt(self, obj):
        """
        decorator to mark a view or all views in a blueprint as exempt from rate limits.

        :param obj: a view function or a :class:`flask.Blueprint`.
        :return: the wrapped view function, or the blueprint itself.
        """
        if isinstance(obj, Blueprint):
            self._blueprint_exempt.add(obj.name)
            return obj

        name = "%s.%s" % (obj.__module__, obj.__name__)

        @wraps(obj)
        def __inner(*a, **k):
            return obj(*a, **k)

        self._exempt_routes.add(name)
        return __inner

    def request_filter(self, fn):
        """
        decorator to mark a function as a filter to be executed
        to check if the request is exempt from rate limiting.

        :param fn: zero-argument callable; a truthy result exempts the request.
        :return: ``fn`` unchanged.
        """
        self._request_filters.append(fn)
        return fn

    def raise_global_limits_warning(self):
        """Warn that ``global_limits`` / ``RATELIMIT_GLOBAL`` is deprecated."""
        warnings.warn(
            "global_limits was a badly named configuration since it is "
            "actually a default limit and not a globally shared limit. Use "
            "default_limits if you want to provide a default or use "
            "application_limits if you intend to really have a global shared "
            "limit",
            UserWarning
        )
665