Completed
Push — master ( 08e939...7140d7 )
by Ali-Akber
28s
created

Limiter.__inner()   A

Complexity

Conditions 1

Size

Total Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
"""
2
the flask extension
3
"""
4
import itertools
5
import logging
6
import sys
7
import time
8
import warnings
9
from functools import wraps
10
11
import six
12
from flask import request, current_app, g, Blueprint
13
from limits.errors import ConfigurationError
14
from limits.storage import storage_from_string, MemoryStorage
15
from limits.strategies import STRATEGIES
16
from werkzeug.http import http_date
17
18
from flask_limiter.wrappers import Limit, LimitGroup
19
from .errors import RateLimitExceeded
20
from .util import get_ipaddr
21
22
23
class C:
    """
    Names of the Flask ``app.config`` keys that :class:`Limiter` reads.
    """

    # feature switches
    ENABLED = "RATELIMIT_ENABLED"
    HEADERS_ENABLED = "RATELIMIT_HEADERS_ENABLED"
    SWALLOW_ERRORS = "RATELIMIT_SWALLOW_ERRORS"

    # storage backend, strategy and key namespace
    STORAGE_URL = "RATELIMIT_STORAGE_URL"
    STORAGE_OPTIONS = "RATELIMIT_STORAGE_OPTIONS"
    STRATEGY = "RATELIMIT_STRATEGY"
    KEY_PREFIX = "RATELIMIT_KEY_PREFIX"

    # limit definitions (GLOBAL_LIMITS is the deprecated name of DEFAULT_LIMITS)
    GLOBAL_LIMITS = "RATELIMIT_GLOBAL"
    DEFAULT_LIMITS = "RATELIMIT_DEFAULT"
    APPLICATION_LIMITS = "RATELIMIT_APPLICATION"
    IN_MEMORY_FALLBACK = "RATELIMIT_IN_MEMORY_FALLBACK"

    # response header names and the Retry-After value format
    HEADER_LIMIT = "RATELIMIT_HEADER_LIMIT"
    HEADER_REMAINING = "RATELIMIT_HEADER_REMAINING"
    HEADER_RESET = "RATELIMIT_HEADER_RESET"
    HEADER_RETRY_AFTER = "RATELIMIT_HEADER_RETRY_AFTER"
    HEADER_RETRY_AFTER_VALUE = "RATELIMIT_HEADER_RETRY_AFTER_VALUE"
41
42
class HEADERS:
    """
    Symbolic identifiers for the rate limit response headers.

    The values serve as keys into the limiter's header-name mapping, so only
    their distinctness matters.
    """

    RESET, REMAINING, LIMIT, RETRY_AFTER = range(1, 5)
47
48
49
# Exponent cap for the back-off used while the primary storage is marked
# dead: ``Limiter.__should_check_backend`` waits ``2 ** count`` seconds between
# storage health checks and resets ``count`` to 0 once it exceeds this value.
MAX_BACKEND_CHECKS = 5


class Limiter(object):
    """
    :param app: :class:`flask.Flask` instance to initialize the extension
     with.
    :param list default_limits: a variable list of strings or callables returning strings denoting global
     limits to apply to all routes. :ref:`ratelimit-string` for more details.
    :param list global_limits: deprecated alias of ``default_limits`` (emits a
     :class:`UserWarning`); entries are merged with ``default_limits``.
    :param list application_limits: a variable list of strings or callables returning strings for limits that
     are applied to the entire application (i.e a shared limit for all routes)
    :param function key_func: a callable that returns the domain to rate limit by.
     Defaults to :func:`get_ipaddr` (with a warning).
    :param bool headers_enabled: whether ``X-RateLimit`` response headers are written.
    :param str strategy: the strategy to use. refer to :ref:`ratelimit-strategy`
    :param str storage_uri: the storage location. refer to :ref:`ratelimit-conf`
    :param dict storage_options: kwargs to pass to the storage implementation upon
      instantiation. The dict is copied; the caller's object is never modified.
    :param bool auto_check: whether to automatically check the rate limit in the before_request
     chain of the application. default ``True``
    :param bool swallow_errors: whether to swallow errors when hitting a rate limit
     (or when reading the window stats for the response headers).
     An exception will still be logged. default ``False``
    :param list in_memory_fallback: a variable list of strings or callables returning strings denoting fallback
     limits to apply when the storage is down.
    :param str retry_after: format of the ``Retry-After`` header value;
     ``'http-date'`` sends an HTTP date, anything else sends seconds.
    :param str key_prefix: prefix prepended to rate limiter keys.
    """

    def __init__(
        self,
        app=None,
        key_func=None,
        global_limits=None,
        default_limits=None,
        application_limits=None,
        headers_enabled=False,
        strategy=None,
        storage_uri=None,
        storage_options=None,
        auto_check=True,
        swallow_errors=False,
        in_memory_fallback=None,
        retry_after=None,
        key_prefix=""
    ):
        # ``None`` replaces the former shared mutable defaults; copying into
        # fresh lists also lets callers pass tuples.
        global_limits = list(global_limits or [])
        default_limits = list(default_limits or [])
        application_limits = list(application_limits or [])
        in_memory_fallback = list(in_memory_fallback or [])

        self.app = app
        self.logger = logging.getLogger("flask-limiter")

        self.enabled = True
        self._exempt_routes = set()
        self._request_filters = []
        self._headers_enabled = headers_enabled
        self._header_mapping = {}
        self._retry_after = retry_after
        self._strategy = strategy
        self._storage_uri = storage_uri
        # init_app() updates this dict in place, so it must be private to this
        # instance rather than the caller's (or a shared default) object.
        self._storage_options = dict(storage_options or {})
        self._auto_check = auto_check
        self._swallow_errors = swallow_errors
        if not key_func:
            warnings.warn(
                "Use of the default `get_ipaddr` function is discouraged."
                " Please refer to https://flask-limiter.readthedocs.org/#rate-limit-domain"
                " for the recommended configuration", UserWarning
            )
        if global_limits:
            self.raise_global_limits_warning()

        self._key_func = key_func or get_ipaddr
        self._key_prefix = key_prefix

        # Merge the deprecated ``global_limits`` with ``default_limits``,
        # dropping duplicates while keeping declaration order (a ``set`` would
        # make the enforcement order arbitrary).
        self._default_limits = []
        seen = set()
        for limit in global_limits + default_limits:
            if limit in seen:
                continue
            seen.add(limit)
            self._default_limits.append(
                LimitGroup(limit, self._key_func, None, False, None, None, None)
            )
        self._application_limits = [
            LimitGroup(
                limit, self._key_func, "global", False, None, None, None
            ) for limit in application_limits
        ]
        self._in_memory_fallback = [
            LimitGroup(limit, self._key_func, None, False, None, None, None)
            for limit in in_memory_fallback
        ]
        self._route_limits = {}
        self._dynamic_route_limits = {}
        self._blueprint_limits = {}
        self._blueprint_dynamic_limits = {}
        self._blueprint_exempt = set()
        self._storage = self._limiter = None
        self._storage_dead = False
        self._fallback_storage = None
        self._fallback_limiter = None
        self.__check_backend_count = 0
        self.__last_check_backend = time.time()

        # The "flask-limiter" logger is process-wide: attach one NullHandler
        # (instead of a new handler per Limiter instance) so the logging
        # module's "no handler" fallback is not triggered for library records.
        if not any(
            isinstance(handler, logging.NullHandler)
            for handler in self.logger.handlers
        ):
            self.logger.addHandler(logging.NullHandler())
        if app:
            self.init_app(app)

    def init_app(self, app):
        """
        :param app: :class:`flask.Flask` instance to rate limit.

        Values passed to the constructor take precedence over ``app.config``;
        missing config keys are filled in with their defaults via
        ``setdefault``.

        :raises ConfigurationError: if the configured strategy is unknown.
        """
        config = app.config
        self.enabled = config.setdefault(C.ENABLED, True)
        self._swallow_errors = config.setdefault(
            C.SWALLOW_ERRORS, self._swallow_errors
        )
        self._headers_enabled = (
            self._headers_enabled
            or config.setdefault(C.HEADERS_ENABLED, False)
        )
        # ``or {}`` tolerates RATELIMIT_STORAGE_OPTIONS explicitly set to None.
        self._storage_options.update(config.get(C.STORAGE_OPTIONS) or {})
        self._storage = storage_from_string(
            self._storage_uri
            or config.setdefault(C.STORAGE_URL, 'memory://'),
            **self._storage_options
        )
        strategy = (
            self._strategy
            or config.setdefault(C.STRATEGY, 'fixed-window')
        )
        if strategy not in STRATEGIES:
            raise ConfigurationError(
                "Invalid rate limiting strategy %s" % strategy
            )
        self._limiter = STRATEGIES[strategy](self._storage)
        self._header_mapping.update(
            {
                HEADERS.RESET:
                self._header_mapping.get(HEADERS.RESET, None)
                or config.setdefault(C.HEADER_RESET, "X-RateLimit-Reset"),
                HEADERS.REMAINING:
                self._header_mapping.get(HEADERS.REMAINING, None)
                or config.setdefault(
                    C.HEADER_REMAINING, "X-RateLimit-Remaining"
                ),
                HEADERS.LIMIT:
                self._header_mapping.get(HEADERS.LIMIT, None)
                or config.setdefault(C.HEADER_LIMIT, "X-RateLimit-Limit"),
                HEADERS.RETRY_AFTER:
                self._header_mapping.get(HEADERS.RETRY_AFTER, None)
                or config.setdefault(C.HEADER_RETRY_AFTER, "Retry-After"),
            }
        )
        self._retry_after = (
            self._retry_after or config.get(C.HEADER_RETRY_AFTER_VALUE)
        )
        self._key_prefix = self._key_prefix or config.get(C.KEY_PREFIX)
        app_limits = config.get(C.APPLICATION_LIMITS, None)
        if not self._application_limits and app_limits:
            self._application_limits = [
                LimitGroup(
                    app_limits, self._key_func, "global", False, None, None,
                    None
                )
            ]

        if config.get(C.GLOBAL_LIMITS, None):
            self.raise_global_limits_warning()
        # RATELIMIT_GLOBAL is the deprecated name of RATELIMIT_DEFAULT; an
        # empty or None value for it must not hide RATELIMIT_DEFAULT.
        conf_limits = (
            config.get(C.GLOBAL_LIMITS) or config.get(C.DEFAULT_LIMITS)
        )
        if not self._default_limits and conf_limits:
            self._default_limits = [
                LimitGroup(
                    conf_limits, self._key_func, None, False, None, None, None
                )
            ]
        fallback_limits = config.get(C.IN_MEMORY_FALLBACK, None)
        if not self._in_memory_fallback and fallback_limits:
            self._in_memory_fallback = [
                LimitGroup(
                    fallback_limits, self._key_func, None, False, None, None,
                    None
                )
            ]
        if self._auto_check:
            app.before_request(self.__check_request_limit)
        app.after_request(self.__inject_headers)

        if self._in_memory_fallback:
            self._fallback_storage = MemoryStorage()
            self._fallback_limiter = STRATEGIES[strategy](
                self._fallback_storage
            )

        # purely for backward compatibility as stated in flask documentation
        if not hasattr(app, 'extensions'):
            app.extensions = {}  # pragma: no cover
        app.extensions['limiter'] = self

    def __should_check_backend(self):
        """
        Decide whether a dead primary storage should be probed again.

        Uses exponential back-off: returns True (and records the attempt) only
        when more than ``2 ** count`` seconds have passed since the last probe;
        ``count`` wraps back to 0 after exceeding ``MAX_BACKEND_CHECKS``.
        """
        if self.__check_backend_count > MAX_BACKEND_CHECKS:
            self.__check_backend_count = 0
        elapsed = time.time() - self.__last_check_backend
        if elapsed <= pow(2, self.__check_backend_count):
            return False
        self.__last_check_backend = time.time()
        self.__check_backend_count += 1
        return True

    def check(self):
        """
        check the limits for the current request

        :raises: RateLimitExceeded
        """
        self.__check_request_limit()

    def reset(self):
        """
        resets the storage if it supports being reset
        """
        try:
            self._storage.reset()
            self.logger.info("Storage has been reset and all limits cleared")
        except NotImplementedError:
            self.logger.warning(
                "This storage type does not support being reset"
            )

    @property
    def limiter(self):
        """
        The active rate limiting strategy: the in-memory fallback while the
        primary storage is marked dead (and fallback limits exist), otherwise
        the primary one.
        """
        if self._storage_dead and self._in_memory_fallback:
            return self._fallback_limiter
        return self._limiter

    def __inject_headers(self, response):
        """
        ``after_request`` hook adding rate limit headers to ``response``.

        Uses ``g.view_rate_limit`` (``[limit, (key_prefix,) key, scope]`` as
        stored by the request check). Storage errors while reading the window
        stats are logged and ignored when ``swallow_errors`` is enabled.
        """
        current_limit = getattr(g, 'view_rate_limit', None)
        if not (self.enabled and self._headers_enabled and current_limit):
            return response
        try:
            window_stats = self.limiter.get_window_stats(*current_limit)
        except Exception:  # no qa
            if not self._swallow_errors:
                six.reraise(*sys.exc_info())
            self.logger.exception(
                "Failed to update rate limit headers. Swallowing error"
            )
            return response
        # window_stats is (reset time, remaining hits).
        reset_in = 1 + window_stats[0]
        if self._retry_after == 'http-date':
            retry_after = http_date(reset_in)
        else:
            retry_after = int(reset_in - time.time())
        response.headers.add(
            self._header_mapping[HEADERS.LIMIT],
            str(current_limit[0].amount)
        )
        response.headers.add(
            self._header_mapping[HEADERS.REMAINING], window_stats[1]
        )
        response.headers.add(self._header_mapping[HEADERS.RESET], reset_in)
        response.headers.add(
            self._header_mapping[HEADERS.RETRY_AFTER], retry_after
        )
        return response

    def __check_request_limit(self):
        """
        Hit every limit applicable to the current request.

        Application limits are always checked; route/blueprint limits are
        checked when present, otherwise the default limits. While the storage
        is dead, only the in-memory fallback limits are used.

        :raises RateLimitExceeded: when one of the limits is exhausted.
        """
        endpoint = request.endpoint or ""
        view_func = current_app.view_functions.get(endpoint, None)
        name = (
            "%s.%s" % (view_func.__module__, view_func.__name__)
            if view_func else ""
        )
        if (not request.endpoint
            or not self.enabled
            or view_func == current_app.send_static_file
            or name in self._exempt_routes
            or request.blueprint in self._blueprint_exempt
            or any(fn() for fn in self._request_filters)
        ):
            return
        # Copy so that adding blueprint limits below never touches the
        # registered route limits.
        limits = list(self._route_limits.get(name) or [])
        dynamic_limits = []
        if name in self._dynamic_route_limits:
            for lim in self._dynamic_route_limits[name]:
                try:
                    dynamic_limits.extend(list(lim))
                except ValueError as e:
                    self.logger.error(
                        "failed to load ratelimit for view function %s (%s)",
                        name, e
                    )
        if request.blueprint:
            if (request.blueprint in self._blueprint_dynamic_limits
                and not dynamic_limits
            ):
                for limit_group in self._blueprint_dynamic_limits[
                    request.blueprint
                ]:
                    try:
                        dynamic_limits.extend(
                            [
                                Limit(
                                    limit.limit, limit.key_func, limit.scope,
                                    limit.per_method, limit.methods,
                                    limit.error_message, limit.exempt_when
                                ) for limit in limit_group
                            ]
                        )
                    except ValueError as e:
                        self.logger.error(
                            "failed to load ratelimit for blueprint %s (%s)",
                            request.blueprint, e
                        )
            if request.blueprint in self._blueprint_limits and not limits:
                limits.extend(self._blueprint_limits[request.blueprint])

        failed_limit = None
        limit_for_header = None
        try:
            all_limits = []
            if self._storage_dead and self._fallback_limiter:
                if self.__should_check_backend() and self._storage.check():
                    self.logger.info("Rate limit storage recovered")
                    self._storage_dead = False
                    self.__check_backend_count = 0
                else:
                    all_limits = list(
                        itertools.chain(*self._in_memory_fallback)
                    )
            if not all_limits:
                all_limits = itertools.chain(
                    itertools.chain(*self._application_limits),
                    (limits + dynamic_limits)
                    or itertools.chain(*self._default_limits)
                )
            for lim in all_limits:
                limit_scope = lim.scope or endpoint
                # An exempt limit, or one restricted to other HTTP methods,
                # is skipped on its own; the remaining limits still apply.
                if lim.is_exempt:
                    continue
                if (lim.methods is not None
                        and request.method.lower() not in lim.methods):
                    continue
                if lim.per_method:
                    limit_scope += ":%s" % request.method
                limit_key = lim.key_func()

                args = [limit_key, limit_scope]
                if not all(args):
                    self.logger.error(
                        "Skipping limit: %s. Empty value found in parameters.",
                        lim.limit
                    )
                    continue
                if self._key_prefix:
                    args = [self._key_prefix] + args
                if not limit_for_header or lim.limit < limit_for_header[0]:
                    limit_for_header = [lim.limit] + args
                if not self.limiter.hit(lim.limit, *args):
                    self.logger.warning(
                        "ratelimit %s (%s) exceeded at endpoint: %s",
                        lim.limit, limit_key, limit_scope
                    )
                    failed_limit = lim
                    limit_for_header = [lim.limit] + args
                    break
            g.view_rate_limit = limit_for_header

            if failed_limit:
                if failed_limit.error_message:
                    exc_description = (
                        failed_limit.error_message()
                        if callable(failed_limit.error_message)
                        else failed_limit.error_message
                    )
                else:
                    exc_description = six.text_type(failed_limit.limit)
                raise RateLimitExceeded(exc_description)
        except Exception as e:  # no qa
            if isinstance(e, RateLimitExceeded):
                six.reraise(*sys.exc_info())
            if self._in_memory_fallback and not self._storage_dead:
                self.logger.warning(
                    "Rate limit storage unreachable - falling back to"
                    " in-memory storage"
                )
                self._storage_dead = True
                # Re-run the whole check against the in-memory fallback.
                self.__check_request_limit()
            elif self._swallow_errors:
                self.logger.exception(
                    "Failed to rate limit. Swallowing error"
                )
            else:
                six.reraise(*sys.exc_info())

    def __limit_decorator(
        self,
        limit_value,
        key_func=None,
        shared=False,
        scope=None,
        per_method=False,
        methods=None,
        error_message=None,
        exempt_when=None
    ):
        """
        Build the decorator used by :meth:`limit` and :meth:`shared_limit`.

        Applied to a view function it registers the limit under the view's
        ``module.name`` and returns a wrapper; applied to a
        :class:`flask.Blueprint` it registers the limit for the blueprint and
        returns the blueprint. ``scope`` is only used when ``shared`` is true.
        """
        _scope = scope if shared else None

        def _inner(obj):
            func = key_func or self._key_func
            is_route = not isinstance(obj, Blueprint)
            name = (
                "%s.%s" % (obj.__module__, obj.__name__) if is_route
                else obj.name
            )
            dynamic_limit, static_limits = None, []
            if callable(limit_value):
                # Callable limits are expanded per request in
                # __check_request_limit, so keep the unexpanded group.
                dynamic_limit = LimitGroup(
                    limit_value, func, _scope, per_method, methods,
                    error_message, exempt_when
                )
            else:
                try:
                    static_limits = list(
                        LimitGroup(
                            limit_value, func, _scope, per_method, methods,
                            error_message, exempt_when
                        )
                    )
                except ValueError as e:
                    self.logger.error(
                        "failed to configure %s %s (%s)",
                        "view function" if is_route else "blueprint", name, e
                    )
            if not is_route:
                if dynamic_limit:
                    self._blueprint_dynamic_limits.setdefault(
                        name, []
                    ).append(dynamic_limit)
                else:
                    self._blueprint_limits.setdefault(
                        name, []
                    ).extend(static_limits)
                return obj

            @wraps(obj)
            def __inner(*a, **k):
                return obj(*a, **k)

            if dynamic_limit:
                self._dynamic_route_limits.setdefault(
                    name, []
                ).append(dynamic_limit)
            else:
                self._route_limits.setdefault(
                    name, []
                ).extend(static_limits)
            return __inner

        return _inner

    def limit(
        self,
        limit_value,
        key_func=None,
        per_method=False,
        methods=None,
        error_message=None,
        exempt_when=None
    ):
        """
        decorator to be used for rate limiting individual routes or blueprints.

        :param limit_value: rate limit string or a callable that returns a string.
         :ref:`ratelimit-string` for more details.
        :param function key_func: function/lambda to extract the unique identifier for
         the rate limit. defaults to remote address of the request.
        :param bool per_method: whether the limit is sub categorized into the http
         method of the request.
        :param list methods: if specified, only the methods in this list will be rate
         limited (default: None).
        :param error_message: string (or callable that returns one) to override the
         error message used in the response.
        :param function exempt_when: callable stored on the limit; when the limit
         reports itself exempt for a request it is skipped for that request.
        :return: a decorator that returns a wrapped view function, or the
         blueprint itself when applied to a :class:`flask.Blueprint`.
        """
        return self.__limit_decorator(
            limit_value,
            key_func,
            per_method=per_method,
            methods=methods,
            error_message=error_message,
            exempt_when=exempt_when
        )

    def shared_limit(
        self,
        limit_value,
        scope,
        key_func=None,
        error_message=None,
        exempt_when=None
    ):
        """
        decorator to be applied to multiple routes sharing the same rate limit.

        :param limit_value: rate limit string or a callable that returns a string.
         :ref:`ratelimit-string` for more details.
        :param scope: a string or callable that returns a string
         for defining the rate limiting scope.
        :param function key_func: function/lambda to extract the unique identifier for
         the rate limit. defaults to remote address of the request.
        :param error_message: string (or callable that returns one) to override the
         error message used in the response.
        :param function exempt_when: callable stored on the limit; when the limit
         reports itself exempt for a request it is skipped for that request.
        """
        return self.__limit_decorator(
            limit_value,
            key_func,
            True,
            scope,
            error_message=error_message,
            exempt_when=exempt_when
        )

    def exempt(self, obj):
        """
        decorator to mark a view or all views in a blueprint as exempt from rate limits.

        :param obj: a view function or a :class:`flask.Blueprint`.
        :return: the wrapped view function, or the blueprint itself.
        """
        if isinstance(obj, Blueprint):
            self._blueprint_exempt.add(obj.name)
            return obj
        name = "%s.%s" % (obj.__module__, obj.__name__)

        @wraps(obj)
        def __inner(*a, **k):
            return obj(*a, **k)

        self._exempt_routes.add(name)
        return __inner

    def request_filter(self, fn):
        """
        decorator to mark a function as a filter to be executed
        to check if the request is exempt from rate limiting.

        The filter is called with no arguments for every request; a truthy
        return value skips rate limiting for that request.
        """
        self._request_filters.append(fn)
        return fn

    def raise_global_limits_warning(self):
        """
        Emit the deprecation warning for ``global_limits``/``RATELIMIT_GLOBAL``.
        """
        warnings.warn(
            "global_limits was a badly named configuration since it is actually a default limit and not a"
            " globally shared limit. Use default_limits if you want to provide a default or use application_limits"
            " if you intend to really have a global shared limit", UserWarning
        )
603