Completed
Push — master ( efd80c...9e0f7a )
by Ali-Akber
29s
created

Limiter.shared_limit()   B

Complexity

Conditions 1

Size

Total Lines 27

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 1
c 1
b 0
f 0
dl 0
loc 27
rs 8.8571
1
"""
2
the flask extension
3
"""
4
import logging
5
import sys
6
import time
7
import warnings
8
from functools import wraps
9
10
import itertools
11
import six
12
from flask import request, current_app, g, Blueprint
13
from limits.errors import ConfigurationError
14
from limits.storage import storage_from_string, MemoryStorage
15
from limits.strategies import STRATEGIES
16
from limits.util import parse_many
17
from werkzeug.http import http_date
18
19
from flask_limiter.wrappers import Limit, LimitGroup
20
from .errors import RateLimitExceeded
21
from .util import get_ipaddr
22
23
24
class C:
    """Names of the ``app.config`` keys consulted by :class:`Limiter`."""

    # -- feature switches
    ENABLED = "RATELIMIT_ENABLED"
    HEADERS_ENABLED = "RATELIMIT_HEADERS_ENABLED"
    SWALLOW_ERRORS = "RATELIMIT_SWALLOW_ERRORS"

    # -- storage backend / strategy
    STORAGE_URL = "RATELIMIT_STORAGE_URL"
    STORAGE_OPTIONS = "RATELIMIT_STORAGE_OPTIONS"
    STRATEGY = "RATELIMIT_STRATEGY"
    KEY_PREFIX = "RATELIMIT_KEY_PREFIX"

    # -- limit definitions (GLOBAL_LIMITS is the deprecated spelling of
    #    DEFAULT_LIMITS; using it triggers a warning)
    GLOBAL_LIMITS = "RATELIMIT_GLOBAL"
    DEFAULT_LIMITS = "RATELIMIT_DEFAULT"
    APPLICATION_LIMITS = "RATELIMIT_APPLICATION"
    IN_MEMORY_FALLBACK = "RATELIMIT_IN_MEMORY_FALLBACK"

    # -- response header names and Retry-After formatting
    HEADER_LIMIT = "RATELIMIT_HEADER_LIMIT"
    HEADER_REMAINING = "RATELIMIT_HEADER_REMAINING"
    HEADER_RESET = "RATELIMIT_HEADER_RESET"
    HEADER_RETRY_AFTER = "RATELIMIT_HEADER_RETRY_AFTER"
    HEADER_RETRY_AFTER_VALUE = "RATELIMIT_HEADER_RETRY_AFTER_VALUE"
41
42
43
class HEADERS:
    """Symbolic identifiers for the rate limit response headers.

    They are the keys of ``Limiter._header_mapping``, which maps each
    identifier to the header name actually written on the response.
    """

    RESET, REMAINING, LIMIT, RETRY_AFTER = 1, 2, 3, 4
48
49
50
# Cap on the back-off exponent used while the storage backend is marked dead:
# Limiter.__should_check_backend only re-probes the backend after more than
# 2 ** count seconds, and restarts count at 0 once it exceeds this value.
MAX_BACKEND_CHECKS = 5
51
52
53
class Limiter(object):
    """
    Flask extension that applies rate limits to the requests of an application.

    :param app: :class:`flask.Flask` instance to initialize the extension
     with.
    :param function key_func: a callable that returns the domain to rate limit by.
     When omitted, :func:`get_ipaddr` is used and a :class:`UserWarning` is issued.
    :param list global_limits: deprecated alias of ``default_limits``; using it
     issues a :class:`UserWarning`.
    :param list default_limits: a variable list of strings denoting default
     limits to apply to all routes. :ref:`ratelimit-string` for more details.
    :param list application_limits: a variable list of strings for limits that
     are applied to the entire application (i.e a shared limit for all routes)
    :param bool headers_enabled: whether ``X-RateLimit`` response headers are written.
    :param str strategy: the strategy to use. refer to :ref:`ratelimit-strategy`
    :param str storage_uri: the storage location. refer to :ref:`ratelimit-conf`
    :param dict storage_options: kwargs to pass to the storage implementation upon
      instantiation. The mapping is copied; the caller's dict is never modified.
    :param bool auto_check: whether to automatically check the rate limit in the before_request
     chain of the application. default ``True``
    :param bool swallow_errors: whether to swallow errors when hitting a rate limit.
     An exception will still be logged. default ``False``
    :param list in_memory_fallback: a variable list of strings denoting fallback
     limits to apply when the storage is down.
    :param str retry_after: ``'http-date'`` to send the ``Retry-After`` header as
     an HTTP date; otherwise the number of seconds until the window resets is sent.
    :param str key_prefix: prefix prepended to rate limiter keys.
    """

    def __init__(
        self,
        app=None,
        key_func=None,
        global_limits=None,
        default_limits=None,
        application_limits=None,
        headers_enabled=False,
        strategy=None,
        storage_uri=None,
        storage_options=None,
        auto_check=True,
        swallow_errors=False,
        in_memory_fallback=None,
        retry_after=None,
        key_prefix=""
    ):
        # ``None`` defaults replace the former ``[]`` / ``{}`` literals: a
        # default container is shared by every call, and ``init_app`` updates
        # ``self._storage_options`` in place, which leaked one app's storage
        # options into every other Limiter created with the default.
        global_limits = global_limits or []
        default_limits = default_limits or []

        self.app = app
        self.logger = logging.getLogger("flask-limiter")

        self.enabled = True
        self._default_limits = []
        self._application_limits = []
        self._in_memory_fallback = []
        self._exempt_routes = set()
        self._request_filters = []
        self._headers_enabled = headers_enabled
        self._header_mapping = {}
        self._retry_after = retry_after
        self._strategy = strategy
        self._storage_uri = storage_uri
        # Private copy: init_app merges the app's config into this dict.
        self._storage_options = dict(storage_options or {})
        self._auto_check = auto_check
        self._swallow_errors = swallow_errors
        if not key_func:
            warnings.warn(
                "Use of the default `get_ipaddr` function is discouraged."
                " Please refer to https://flask-limiter.readthedocs.org/#rate-limit-domain"
                " for the recommended configuration", UserWarning
            )
        if global_limits:
            self.raise_global_limits_warning()

        self._key_func = key_func or get_ipaddr
        self._key_prefix = key_prefix

        # global_limits is merged into the defaults. Duplicates are dropped
        # while keeping declaration order, so evaluation order is
        # deterministic (iterating a set of strings is not).
        seen = set()
        for limit in global_limits + default_limits:
            if limit in seen:
                continue
            seen.add(limit)
            self._default_limits.append(
                LimitGroup(
                    limit, self._key_func, None, False, None, None, None
                )
            )
        for limit in application_limits or []:
            self._application_limits.append(
                LimitGroup(
                    limit, self._key_func, "global", False, None, None, None
                )
            )
        for limit in in_memory_fallback or []:
            self._in_memory_fallback.append(
                LimitGroup(
                    limit, self._key_func, None, False, None, None, None
                )
            )
        self._route_limits = {}
        self._dynamic_route_limits = {}
        self._blueprint_limits = {}
        self._blueprint_dynamic_limits = {}
        self._blueprint_exempt = set()
        self._storage = self._limiter = None
        self._storage_dead = False
        self._fallback_limiter = None
        self.__check_backend_count = 0
        self.__last_check_backend = time.time()

        # The "flask-limiter" logger is process-wide. Attach the no-op
        # handler only once, so creating many Limiter instances does not
        # keep piling handlers onto the same logger.
        if not any(
            isinstance(handler, logging.NullHandler)
            for handler in self.logger.handlers
        ):
            self.logger.addHandler(logging.NullHandler())
        if app:
            self.init_app(app)

    def init_app(self, app):
        """
        Bind the extension to ``app``: read its ``RATELIMIT_*`` config,
        build the storage and strategy, and register the request hooks.

        :param app: :class:`flask.Flask` instance to rate limit.
        :raises ConfigurationError: if the configured strategy is unknown.
        """
        self.enabled = app.config.setdefault(C.ENABLED, True)
        self._swallow_errors = app.config.setdefault(
            C.SWALLOW_ERRORS, self._swallow_errors
        )
        self._headers_enabled = (
            self._headers_enabled
            or app.config.setdefault(C.HEADERS_ENABLED, False)
        )
        self._storage_options.update(app.config.get(C.STORAGE_OPTIONS, {}))
        self._storage = storage_from_string(
            self._storage_uri
            or app.config.setdefault(C.STORAGE_URL, 'memory://'),
            **self._storage_options
        )
        strategy = (
            self._strategy
            or app.config.setdefault(C.STRATEGY, 'fixed-window')
        )
        if strategy not in STRATEGIES:
            raise ConfigurationError(
                "Invalid rate limiting strategy %s" % strategy
            )
        self._limiter = STRATEGIES[strategy](self._storage)
        # Header names already present in the mapping win over app config.
        self._header_mapping.update({
            HEADERS.RESET:
            self._header_mapping.get(HEADERS.RESET, None)
            or app.config.setdefault(C.HEADER_RESET, "X-RateLimit-Reset"),
            HEADERS.REMAINING:
            self._header_mapping.get(HEADERS.REMAINING, None) or
            app.config.setdefault(C.HEADER_REMAINING, "X-RateLimit-Remaining"),
            HEADERS.LIMIT:
            self._header_mapping.get(HEADERS.LIMIT, None)
            or app.config.setdefault(C.HEADER_LIMIT, "X-RateLimit-Limit"),
            HEADERS.RETRY_AFTER:
            self._header_mapping.get(HEADERS.RETRY_AFTER, None)
            or app.config.setdefault(C.HEADER_RETRY_AFTER, "Retry-After"),
        })
        self._retry_after = (
            self._retry_after or app.config.get(C.HEADER_RETRY_AFTER_VALUE)
        )
        self._key_prefix = (self._key_prefix or app.config.get(C.KEY_PREFIX))
        # Limits from app config are only used when none were passed to the
        # constructor.
        app_limits = app.config.get(C.APPLICATION_LIMITS, None)
        if not self._application_limits and app_limits:
            self._application_limits = [
                LimitGroup(
                    app_limits, self._key_func, "global", False, None, None,
                    None
                )
            ]

        if app.config.get(C.GLOBAL_LIMITS, None):
            self.raise_global_limits_warning()
        conf_limits = app.config.get(
            C.GLOBAL_LIMITS, app.config.get(C.DEFAULT_LIMITS, None)
        )
        if not self._default_limits and conf_limits:
            self._default_limits = [
                LimitGroup(
                    conf_limits, self._key_func, None, False, None, None, None
                )
            ]
        fallback_limits = app.config.get(C.IN_MEMORY_FALLBACK, None)
        if not self._in_memory_fallback and fallback_limits:
            self._in_memory_fallback = [
                LimitGroup(
                    fallback_limits, self._key_func, None, False, None, None,
                    None
                )
            ]
        if self._auto_check:
            app.before_request(self.__check_request_limit)
        app.after_request(self.__inject_headers)

        if self._in_memory_fallback:
            self._fallback_storage = MemoryStorage()
            self._fallback_limiter = STRATEGIES[strategy](
                self._fallback_storage
            )

        # purely for backward compatibility as stated in flask documentation
        if not hasattr(app, 'extensions'):
            app.extensions = {}  # pragma: no cover
        app.extensions['limiter'] = self

    def __should_check_backend(self):
        """
        Decide whether a dead storage backend should be probed again.

        Probes are spaced by an exponential back-off: more than
        ``2 ** count`` seconds must elapse since the last probe, and the
        counter wraps to 0 once it exceeds ``MAX_BACKEND_CHECKS``.
        """
        if self.__check_backend_count > MAX_BACKEND_CHECKS:
            self.__check_backend_count = 0
        if time.time() - self.__last_check_backend > pow(
            2, self.__check_backend_count
        ):
            self.__last_check_backend = time.time()
            self.__check_backend_count += 1
            return True
        return False

    def check(self):
        """
        check the limits for the current request

        :raises: RateLimitExceeded
        """
        self.__check_request_limit()

    def reset(self):
        """
        resets the storage if it supports being reset
        """
        try:
            self._storage.reset()
            self.logger.info("Storage has been reset and all limits cleared")
        except NotImplementedError:
            self.logger.warning(
                "This storage type does not support being reset"
            )

    @property
    def limiter(self):
        """The active strategy: the in-memory fallback while storage is dead."""
        if self._storage_dead and self._in_memory_fallback:
            return self._fallback_limiter
        else:
            return self._limiter

    def __inject_headers(self, response):
        """after_request hook writing the rate limit headers, when enabled."""
        current_limit = getattr(g, 'view_rate_limit', None)
        if self.enabled and self._headers_enabled and current_limit:
            window_stats = self.limiter.get_window_stats(*current_limit)
            reset_in = 1 + window_stats[0]
            response.headers.add(
                self._header_mapping[HEADERS.LIMIT],
                str(current_limit[0].amount)
            )
            response.headers.add(
                self._header_mapping[HEADERS.REMAINING], window_stats[1]
            )
            response.headers.add(self._header_mapping[HEADERS.RESET], reset_in)
            if self._retry_after == 'http-date':
                retry_after = http_date(reset_in)
            else:
                retry_after = int(reset_in - time.time())
            response.headers.add(
                self._header_mapping[HEADERS.RETRY_AFTER], retry_after
            )
        return response

    def __check_request_limit(self):
        """
        Hit every applicable limit for the current request.

        :raises RateLimitExceeded: when one of the limits is exhausted.
        """
        endpoint = request.endpoint or ""
        view_func = current_app.view_functions.get(endpoint, None)
        name = (
            "%s.%s" % (view_func.__module__, view_func.__name__)
            if view_func else ""
        )
        if (not request.endpoint
            or not self.enabled
            or view_func == current_app.send_static_file
            or name in self._exempt_routes
            or request.blueprint in self._blueprint_exempt
            or any(fn() for fn in self._request_filters)
        ):
            return
        # Copy so that extending with blueprint limits below can never
        # modify the registered route limits.
        limits = list(self._route_limits.get(name, []))
        dynamic_limits = []
        if name in self._dynamic_route_limits:
            for lim in self._dynamic_route_limits[name]:
                try:
                    dynamic_limits.extend(list(lim))
                except ValueError as e:
                    self.logger.error(
                        "failed to load ratelimit for view function %s (%s)",
                        name, e
                    )
        if request.blueprint:
            if (request.blueprint in self._blueprint_dynamic_limits
                and not dynamic_limits
            ):
                for limit_group in self._blueprint_dynamic_limits[
                    request.blueprint
                ]:
                    try:
                        dynamic_limits.extend([
                            Limit(
                                limit.limit, limit.key_func, limit.scope,
                                limit.per_method, limit.methods,
                                limit.error_message, limit.exempt_when
                            ) for limit in limit_group
                        ])
                    except ValueError as e:
                        self.logger.error(
                            "failed to load ratelimit for blueprint %s (%s)",
                            request.blueprint, e
                        )
            if (request.blueprint in self._blueprint_limits and not limits):
                limits.extend(self._blueprint_limits[request.blueprint])

        failed_limit = None
        limit_for_header = None
        try:
            all_limits = []
            if self._storage_dead and self._fallback_limiter:
                if self.__should_check_backend() and self._storage.check():
                    self.logger.info("Rate limit storage recovered")
                    self._storage_dead = False
                    self.__check_backend_count = 0
                else:
                    all_limits = list(
                        itertools.chain(*self._in_memory_fallback)
                    )
            if not all_limits:
                all_limits = itertools.chain(
                    itertools.chain(*self._application_limits),
                    (limits + dynamic_limits)
                    or itertools.chain(*self._default_limits)
                )
            for lim in all_limits:
                limit_scope = lim.scope or endpoint
                # An exempt limit, or one restricted to other HTTP methods,
                # only skips itself. Returning here used to skip every
                # remaining limit for the request as well.
                if lim.is_exempt:
                    continue
                if (lim.methods is not None
                        and request.method.lower() not in lim.methods):
                    continue
                if lim.per_method:
                    limit_scope += ":%s" % request.method
                limit_key = lim.key_func()

                args = [limit_key, limit_scope]
                if not all(args):
                    self.logger.error(
                        "Skipping limit: %s. Empty value found in parameters.",
                        lim.limit
                    )
                    continue
                # Report the tightest limit in the headers unless one fails.
                if not limit_for_header or lim.limit < limit_for_header[0]:
                    limit_for_header = (lim.limit, limit_key, limit_scope)
                if self._key_prefix:
                    args = [self._key_prefix] + args
                if not self.limiter.hit(lim.limit, *args):
                    self.logger.warning(
                        "ratelimit %s (%s) exceeded at endpoint: %s",
                        lim.limit, limit_key, limit_scope
                    )
                    failed_limit = lim
                    limit_for_header = (lim.limit, limit_key, limit_scope)
                    break
            g.view_rate_limit = limit_for_header

            if failed_limit:
                if failed_limit.error_message:
                    exc_description = failed_limit.error_message if not callable(
                        failed_limit.error_message
                    ) else failed_limit.error_message()
                else:
                    exc_description = six.text_type(failed_limit.limit)
                raise RateLimitExceeded(exc_description)
        except RateLimitExceeded:
            raise
        except Exception:  # noqa - storage errors of any kind end up here
            if self._in_memory_fallback and not self._storage_dead:
                self.logger.warning(
                    "Rate limit storage unreachable - falling back to"
                    " in-memory storage"
                )
                self._storage_dead = True
                self.__check_request_limit()
            elif self._swallow_errors:
                self.logger.exception(
                    "Failed to rate limit. Swallowing error"
                )
            else:
                raise

    def __limit_decorator(
        self,
        limit_value,
        key_func=None,
        shared=False,
        scope=None,
        per_method=False,
        methods=None,
        error_message=None,
        exempt_when=None
    ):
        """
        Build the decorator used by :meth:`limit` and :meth:`shared_limit`.

        Limits are registered by dotted view-function name (for routes) or by
        blueprint name. A callable ``limit_value`` is stored as a dynamic
        limit and re-evaluated per request; a string is parsed once here.
        """
        _scope = scope if shared else None

        def _inner(obj):
            func = key_func or self._key_func
            is_route = not isinstance(obj, Blueprint)
            name = "%s.%s" % (
                obj.__module__, obj.__name__
            ) if is_route else obj.name
            dynamic_limit, static_limits = None, []
            if callable(limit_value):
                dynamic_limit = LimitGroup(
                    limit_value, func, _scope, per_method, methods,
                    error_message, exempt_when
                )
            else:
                try:
                    static_limits = list(
                        LimitGroup(
                            limit_value, func, _scope, per_method, methods,
                            error_message, exempt_when
                        )
                    )
                except ValueError as e:
                    self.logger.error(
                        "failed to configure %s %s (%s)", "view function"
                        if is_route else "blueprint", name, e
                    )
            if isinstance(obj, Blueprint):
                if dynamic_limit:
                    self._blueprint_dynamic_limits.setdefault(
                        name, []
                    ).append(dynamic_limit)
                else:
                    self._blueprint_limits.setdefault(name, []
                                                      ).extend(static_limits)
            else:

                @wraps(obj)
                def __inner(*a, **k):
                    return obj(*a, **k)

                if dynamic_limit:
                    self._dynamic_route_limits.setdefault(
                        name, []
                    ).append(dynamic_limit)
                else:
                    self._route_limits.setdefault(name,
                                                  []).extend(static_limits)
                return __inner

        return _inner

    def limit(
        self,
        limit_value,
        key_func=None,
        per_method=False,
        methods=None,
        error_message=None,
        exempt_when=None
    ):
        """
        decorator to be used for rate limiting individual routes or blueprints.

        :param limit_value: rate limit string or a callable that returns a string.
         :ref:`ratelimit-string` for more details.
        :param function key_func: function/lambda to extract the unique identifier for
         the rate limit. defaults to remote address of the request.
        :param bool per_method: whether the limit is sub categorized into the http
         method of the request.
        :param list methods: if specified, only the methods in this list will be rate
         limited (default: None).
        :param error_message: string (or callable that returns one) to override the
         error message used in the response.
        :param exempt_when: optional callable handed to the limit; when the
         resulting limit reports ``is_exempt`` for a request, it is skipped.
        :return: the decorator.
        """
        return self.__limit_decorator(
            limit_value,
            key_func,
            per_method=per_method,
            methods=methods,
            error_message=error_message,
            exempt_when=exempt_when
        )

    def shared_limit(
        self,
        limit_value,
        scope,
        key_func=None,
        error_message=None,
        exempt_when=None
    ):
        """
        decorator to be applied to multiple routes sharing the same rate limit.

        :param limit_value: rate limit string or a callable that returns a string.
         :ref:`ratelimit-string` for more details.
        :param scope: a string or callable that returns a string
         for defining the rate limiting scope.
        :param function key_func: function/lambda to extract the unique identifier for
         the rate limit. defaults to remote address of the request.
        :param error_message: string (or callable that returns one) to override the
         error message used in the response.
        :param exempt_when: optional callable handed to the limit; when the
         resulting limit reports ``is_exempt`` for a request, it is skipped.
        :return: the decorator.
        """
        return self.__limit_decorator(
            limit_value,
            key_func,
            True,
            scope,
            error_message=error_message,
            exempt_when=exempt_when
        )

    def exempt(self, obj):
        """
        decorator to mark a view or all views in a blueprint as exempt from rate limits.
        """
        if not isinstance(obj, Blueprint):
            name = "%s.%s" % (obj.__module__, obj.__name__)

            @wraps(obj)
            def __inner(*a, **k):
                return obj(*a, **k)

            self._exempt_routes.add(name)
            return __inner
        else:
            self._blueprint_exempt.add(obj.name)

    def request_filter(self, fn):
        """
        decorator to mark a function as a filter to be executed
        to check if the request is exempt from rate limiting.
        """
        self._request_filters.append(fn)
        return fn

    def raise_global_limits_warning(self):
        """Warn that the deprecated ``global_limits`` option is in use."""
        warnings.warn(
            "global_limits was a badly named configuration since it is actually a default limit and not a"
            " globally shared limit. Use default_limits if you want to provide a default or use application_limits"
            " if you intend to really have a global shared limit", UserWarning
        )
589