Completed
Push — master ( 7140d7...45c1cc )
by Ali-Akber
32s
created

Limiter.__evaluate_limits()   F

Complexity

Conditions 14

Size

Total Lines 44

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 14
c 0
b 0
f 0
dl 0
loc 44
rs 2.7581

How to fix   Complexity   

Complexity

Complex methods like Limiter.__evaluate_limits() often do a lot of different things. To break such a method (and the class that contains it) down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""
2
the flask extension
3
"""
4
import itertools
5
import logging
6
import sys
7
import time
8
import warnings
9
from functools import wraps
10
11
import six
12
from flask import request, current_app, g, Blueprint
13
from limits.errors import ConfigurationError
14
from limits.storage import storage_from_string, MemoryStorage
15
from limits.strategies import STRATEGIES
16
from werkzeug.http import http_date
17
18
from flask_limiter.wrappers import Limit, LimitGroup
19
from .errors import RateLimitExceeded
20
from .util import get_ipaddr
21
22
23
class C:
    """Names of the ``app.config`` keys consulted by the extension."""

    # switches
    ENABLED = "RATELIMIT_ENABLED"
    HEADERS_ENABLED = "RATELIMIT_HEADERS_ENABLED"
    SWALLOW_ERRORS = "RATELIMIT_SWALLOW_ERRORS"

    # storage backend and strategy
    STORAGE_URL = "RATELIMIT_STORAGE_URL"
    STORAGE_OPTIONS = "RATELIMIT_STORAGE_OPTIONS"
    STRATEGY = "RATELIMIT_STRATEGY"
    KEY_PREFIX = "RATELIMIT_KEY_PREFIX"

    # limit definitions
    GLOBAL_LIMITS = "RATELIMIT_GLOBAL"
    DEFAULT_LIMITS = "RATELIMIT_DEFAULT"
    APPLICATION_LIMITS = "RATELIMIT_APPLICATION"
    IN_MEMORY_FALLBACK = "RATELIMIT_IN_MEMORY_FALLBACK"

    # response header names and the format of the retry-after value
    HEADER_LIMIT = "RATELIMIT_HEADER_LIMIT"
    HEADER_REMAINING = "RATELIMIT_HEADER_REMAINING"
    HEADER_RESET = "RATELIMIT_HEADER_RESET"
    HEADER_RETRY_AFTER = "RATELIMIT_HEADER_RETRY_AFTER"
    HEADER_RETRY_AFTER_VALUE = "RATELIMIT_HEADER_RETRY_AFTER_VALUE"
40
41
42
class HEADERS:
    """Integer ids used as keys of the limiter's header-name mapping."""

    RESET, REMAINING, LIMIT, RETRY_AFTER = 1, 2, 3, 4
47
48
49
# Ceiling for the counter that drives the exponential wait between probes
# of a dead storage backend; once the counter exceeds it, it restarts at 0.
MAX_BACKEND_CHECKS = 5
50
51
52
class Limiter(object):
    """
    Flask extension that evaluates rate limits for incoming requests.

    :param app: :class:`flask.Flask` instance to initialize the extension
     with.
    :param function key_func: a callable that returns the domain to rate limit by.
     When omitted, ``get_ipaddr`` is used and a :class:`UserWarning` is issued.
    :param list global_limits: deprecated alias of ``default_limits``; both lists
     are merged and a :class:`UserWarning` is issued when it is non-empty.
    :param list default_limits: a variable list of strings or callables returning strings denoting global
     limits to apply to all routes. :ref:`ratelimit-string` for  more details.
    :param list application_limits: a variable list of strings or callables returning strings for limits that
     are applied to the entire application (i.e a shared limit for all routes)
    :param bool headers_enabled: whether ``X-RateLimit`` response headers are written.
    :param str strategy: the strategy to use. refer to :ref:`ratelimit-strategy`
    :param str storage_uri: the storage location. refer to :ref:`ratelimit-conf`
    :param dict storage_options: kwargs to pass to the storage implementation upon
      instantiation.
    :param bool auto_check: whether to automatically check the rate limit in the before_request
     chain of the application. default ``True``
    :param bool swallow_errors: whether to swallow errors when hitting a rate limit.
     An exception will still be logged. default ``False``
    :param list in_memory_fallback: a variable list of strings or callables returning strings denoting fallback
     limits to apply when the storage is down.
    :param str retry_after: format of the retry-after header value. ``'http-date'``
     yields an HTTP date, any other value the number of seconds as an integer.
    :param str key_prefix: prefix prepended to rate limiter keys.
    """
74
75
    def __init__(
76
        self,
77
        app=None,
78
        key_func=None,
79
        global_limits=[],
80
        default_limits=[],
81
        application_limits=[],
82
        headers_enabled=False,
83
        strategy=None,
84
        storage_uri=None,
85
        storage_options={},
86
        auto_check=True,
87
        swallow_errors=False,
88
        in_memory_fallback=[],
89
        retry_after=None,
90
        key_prefix=""
91
    ):
92
        self.app = app
93
        self.logger = logging.getLogger("flask-limiter")
94
95
        self.enabled = True
96
        self._default_limits = []
97
        self._application_limits = []
98
        self._in_memory_fallback = []
99
        self._exempt_routes = set()
100
        self._request_filters = []
101
        self._headers_enabled = headers_enabled
102
        self._header_mapping = {}
103
        self._retry_after = retry_after
104
        self._strategy = strategy
105
        self._storage_uri = storage_uri
106
        self._storage_options = storage_options
107
        self._auto_check = auto_check
108
        self._swallow_errors = swallow_errors
109
        if not key_func:
110
            warnings.warn(
111
                "Use of the default `get_ipaddr` function is discouraged."
112
                " Please refer to https://flask-limiter.readthedocs.org/#rate-limit-domain"
113
                " for the recommended configuration", UserWarning
114
            )
115
        if global_limits:
116
            self.raise_global_limits_warning()
117
118
        self._key_func = key_func or get_ipaddr
119
        self._key_prefix = key_prefix
120
121
        for limit in set(global_limits + default_limits):
122
            self._default_limits.extend(
123
                [
124
                    LimitGroup(
125
                        limit, self._key_func, None, False, None, None, None
126
                    )
127
                ]
128
            )
129
        for limit in application_limits:
130
            self._application_limits.extend(
131
                [
132
                    LimitGroup(
133
                        limit, self._key_func, "global", False, None, None,
134
                        None
135
                    )
136
                ]
137
            )
138
        for limit in in_memory_fallback:
139
            self._in_memory_fallback.extend(
140
                [
141
                    LimitGroup(
142
                        limit, self._key_func, None, False, None, None, None
143
                    )
144
                ]
145
            )
146
        self._route_limits = {}
147
        self._dynamic_route_limits = {}
148
        self._blueprint_limits = {}
149
        self._blueprint_dynamic_limits = {}
150
        self._blueprint_exempt = set()
151
        self._storage = self._limiter = None
152
        self._storage_dead = False
153
        self._fallback_limiter = None
154
        self.__check_backend_count = 0
155
        self.__last_check_backend = time.time()
156
        self.__marked_for_limiting = {}
157
158
        class BlackHoleHandler(logging.StreamHandler):
159
            def emit(*_):
160
                return
161
162
        self.logger.addHandler(BlackHoleHandler())
163
        if app:
164
            self.init_app(app)
165
166
    def init_app(self, app):
        """
        Bind the limiter to ``app``: resolve the configuration, create the
        storage and strategy objects and register the request hooks.

        Values passed to the constructor win over ``app.config``; for several
        keys the effective default is written back through ``setdefault``.

        :param app: :class:`flask.Flask` instance to rate limit.
        :raises ConfigurationError: if the resolved strategy is not a key of
         ``STRATEGIES``.
        """
        config = app.config
        self.enabled = config.setdefault(C.ENABLED, True)
        self._swallow_errors = config.setdefault(
            C.SWALLOW_ERRORS, self._swallow_errors
        )
        self._headers_enabled = (
            self._headers_enabled
            or config.setdefault(C.HEADERS_ENABLED, False)
        )
        # Merge into a new dict instead of updating in place: the existing
        # object may be the caller's dict or one shared by several limiters.
        storage_options = dict(self._storage_options)
        storage_options.update(config.get(C.STORAGE_OPTIONS, {}))
        self._storage_options = storage_options
        self._storage = storage_from_string(
            self._storage_uri
            or config.setdefault(C.STORAGE_URL, 'memory://'),
            **self._storage_options
        )
        strategy = (
            self._strategy
            or config.setdefault(C.STRATEGY, 'fixed-window')
        )
        if strategy not in STRATEGIES:
            raise ConfigurationError(
                "Invalid rate limiting strategy %s" % strategy
            )
        self._limiter = STRATEGIES[strategy](self._storage)

        # A header name already present in the mapping is kept; otherwise the
        # configured (or default) name is used.
        for header, conf_key, default_name in (
            (HEADERS.RESET, C.HEADER_RESET, "X-RateLimit-Reset"),
            (HEADERS.REMAINING, C.HEADER_REMAINING, "X-RateLimit-Remaining"),
            (HEADERS.LIMIT, C.HEADER_LIMIT, "X-RateLimit-Limit"),
            (HEADERS.RETRY_AFTER, C.HEADER_RETRY_AFTER, "Retry-After"),
        ):
            self._header_mapping[header] = (
                self._header_mapping.get(header, None)
                or config.setdefault(conf_key, default_name)
            )
        self._retry_after = (
            self._retry_after or config.get(C.HEADER_RETRY_AFTER_VALUE)
        )
        self._key_prefix = (self._key_prefix or config.get(C.KEY_PREFIX))

        # Limits from the configuration are only used when none were passed
        # to the constructor.
        app_limits = config.get(C.APPLICATION_LIMITS, None)
        if not self._application_limits and app_limits:
            self._application_limits = [
                LimitGroup(
                    app_limits, self._key_func, "global", False, None, None,
                    None
                )
            ]

        if config.get(C.GLOBAL_LIMITS, None):
            self.raise_global_limits_warning()
        # the deprecated RATELIMIT_GLOBAL key takes precedence when present
        conf_limits = config.get(
            C.GLOBAL_LIMITS, config.get(C.DEFAULT_LIMITS, None)
        )
        if not self._default_limits and conf_limits:
            self._default_limits = [
                LimitGroup(
                    conf_limits, self._key_func, None, False, None, None, None
                )
            ]
        fallback_limits = config.get(C.IN_MEMORY_FALLBACK, None)
        if not self._in_memory_fallback and fallback_limits:
            self._in_memory_fallback = [
                LimitGroup(
                    fallback_limits, self._key_func, None, False, None, None,
                    None
                )
            ]
        if self._auto_check:
            app.before_request(self.__check_request_limit)
        app.after_request(self.__inject_headers)

        if self._in_memory_fallback:
            # separate in-memory limiter used while the main storage is dead
            self._fallback_storage = MemoryStorage()
            self._fallback_limiter = STRATEGIES[strategy](
                self._fallback_storage
            )

        # purely for backward compatibility as stated in flask documentation
        if not hasattr(app, 'extensions'):
            app.extensions = {}  # pragma: no cover
        app.extensions['limiter'] = self
257
258
    def __should_check_backend(self):
        """
        Tell whether the storage backend should be probed again now.

        The required pause is ``2 ** count`` seconds since the last probe;
        ``count`` grows with every probe and restarts at zero once it exceeds
        ``MAX_BACKEND_CHECKS``.

        :return: ``True`` (after recording the probe) when the pause has
         elapsed, ``False`` otherwise.
        """
        if self.__check_backend_count > MAX_BACKEND_CHECKS:
            self.__check_backend_count = 0
        required_pause = 2 ** self.__check_backend_count
        if time.time() - self.__last_check_backend <= required_pause:
            return False
        self.__last_check_backend = time.time()
        self.__check_backend_count += 1
        return True
268
269
    def check(self):
        """
        Evaluate the rate limits that apply to the current request, outside
        of the ``before_request`` chain.

        :raises: RateLimitExceeded
        """
        self.__check_request_limit(in_middleware=False)
276
277
    def reset(self):
278
        """
279
        resets the storage if it supports being reset
280
        """
281
        try:
282
            self._storage.reset()
283
            self.logger.info("Storage has been reset and all limits cleared")
284
        except NotImplementedError:
285
            self.logger.warning(
286
                "This storage type does not support being reset"
287
            )
288
289
    @property
290
    def limiter(self):
291
        if self._storage_dead and self._in_memory_fallback:
292
            return self._fallback_limiter
293
        else:
294
            return self._limiter
295
296
    def __inject_headers(self, response):
        """
        ``after_request`` hook that adds the rate limit headers.

        Nothing is added unless the limiter and its headers are enabled and
        ``g.view_rate_limit`` was set for the request.

        :param response: the outgoing response
        :return: the same response object
        """
        current_limit = getattr(g, 'view_rate_limit', None)
        if not (self.enabled and self._headers_enabled and current_limit):
            return response

        # first element drives the reset header, second the remaining header
        window_stats = self.limiter.get_window_stats(*current_limit)
        reset_in = 1 + window_stats[0]
        header_names = self._header_mapping
        headers = response.headers
        headers.add(
            header_names[HEADERS.LIMIT], str(current_limit[0].amount)
        )
        headers.add(header_names[HEADERS.REMAINING], window_stats[1])
        headers.add(header_names[HEADERS.RESET], reset_in)
        if self._retry_after == 'http-date':
            retry_after = http_date(reset_in)
        else:
            retry_after = int(reset_in - time.time())
        headers.add(header_names[HEADERS.RETRY_AFTER], retry_after)
        return response
315
316
    def __evaluate_limits(self, endpoint, limits):
        """
        Register a hit against every applicable limit and raise as soon as
        one of them is exceeded.

        A limit that is exempt, or that is restricted to HTTP methods other
        than the current one, is skipped on its own: the remaining limits are
        still evaluated. A limit whose key or scope is empty is skipped with
        an error log entry.

        ``g.view_rate_limit`` is set to ``[limit, (prefix,) key, scope]`` of
        the exceeded limit, or else of the smallest evaluated limit, or to
        ``None`` when no limit was evaluated.

        :param str endpoint: scope used for limits that define none
        :param limits: iterable of limit objects to evaluate in order
        :raises RateLimitExceeded: described by the limit's ``error_message``
         (called first when callable) or by the text of the limit itself.
        """
        failed_limit = None
        limit_for_header = None
        for lim in limits:
            limit_scope = lim.scope or endpoint
            if lim.is_exempt:
                # only this limit is exempt; keep checking the others
                continue
            if (lim.methods is not None
                    and request.method.lower() not in lim.methods):
                # this limit does not cover the current HTTP method
                continue
            if lim.per_method:
                limit_scope += ":%s" % request.method
            limit_key = lim.key_func()

            args = [limit_key, limit_scope]
            if not all(args):
                self.logger.error(
                    "Skipping limit: %s. Empty value found in parameters.",
                    lim.limit
                )
                continue
            if self._key_prefix:
                args = [self._key_prefix] + args
            # report the smallest limit in the headers ...
            if not limit_for_header or lim.limit < limit_for_header[0]:
                limit_for_header = [lim.limit] + args
            if not self.limiter.hit(lim.limit, *args):
                self.logger.warning(
                    "ratelimit %s (%s) exceeded at endpoint: %s",
                    lim.limit, limit_key, limit_scope
                )
                failed_limit = lim
                # ... unless one was exceeded, which then takes its place
                limit_for_header = [lim.limit] + args
                break
        g.view_rate_limit = limit_for_header

        if failed_limit:
            message = failed_limit.error_message
            if message:
                exc_description = message() if callable(message) else message
            else:
                exc_description = six.text_type(failed_limit.limit)
            raise RateLimitExceeded(exc_description)
360
361
    def __check_request_limit(self, in_middleware=True):
        """
        Collect the limits that apply to the current request and evaluate
        them.

        Nothing happens when the request has no endpoint, the limiter is
        disabled, the view is the static file handler, the route or blueprint
        is exempt, a request filter returns a true value, or rate limiting
        was already completed for this request.

        When evaluation fails with anything but ``RateLimitExceeded`` and
        fallback limits exist, the storage is marked dead and the check is
        repeated once against the in-memory fallback. Otherwise the error is
        logged and swallowed or re-raised depending on ``swallow_errors``.

        :param bool in_middleware: ``True`` when invoked from the
         ``before_request`` chain, ``False`` when invoked from a decorated
         view or from :meth:`check`.
        :raises RateLimitExceeded: when a limit is exceeded
        """
        endpoint = request.endpoint or ""
        view_func = current_app.view_functions.get(endpoint, None)
        name = (
            "%s.%s" % (view_func.__module__, view_func.__name__)
            if view_func else ""
        )
        if (not request.endpoint
            or not self.enabled
            or view_func == current_app.send_static_file
            or name in self._exempt_routes
            or request.blueprint in self._blueprint_exempt
            or any(fn() for fn in self._request_filters)
            or g.get("_rate_limiting_complete")
        ):
            return
        limits, dynamic_limits = [], []

        # this is to ensure backward compatibility with behavior that
        # existed accidentally, i.e::
        #
        # @limiter.limit(...)
        # @app.route('...')
        # def func(...):
        #
        # The above setup would work in pre 1.0 versions because the decorator
        # was not acting immediately and instead simply registering the rate
        # limiting. The correct way to use the decorator is to wrap
        # the limiter with the route, i.e::
        #
        # @app.route(...)
        # @limiter.limit(...)
        # def func(...):

        implicit_decorator = view_func in self.__marked_for_limiting.get(
            name, []
        )

        if not in_middleware or implicit_decorator:
            # `or []` yields a fresh list when nothing is registered, so the
            # extend() further down never touches the stored route limits
            limits = self._route_limits.get(name) or []
            dynamic_limits = []
            if name in self._dynamic_route_limits:
                for lim in self._dynamic_route_limits[name]:
                    try:
                        dynamic_limits.extend(list(lim))
                    except ValueError as e:
                        self.logger.error(
                            "failed to load ratelimit for view function %s (%s)",
                            name, e
                        )
        if request.blueprint:
            # blueprint limits only apply when the route has none of its own
            if (request.blueprint in self._blueprint_dynamic_limits
                and not dynamic_limits
            ):
                for limit_group in self._blueprint_dynamic_limits[
                    request.blueprint
                ]:
                    try:
                        dynamic_limits.extend(
                            [
                                Limit(
                                    limit.limit, limit.key_func, limit.scope,
                                    limit.per_method, limit.methods,
                                    limit.error_message, limit.exempt_when
                                ) for limit in limit_group
                            ]
                        )
                    except ValueError as e:
                        self.logger.error(
                            "failed to load ratelimit for blueprint %s (%s)",
                            request.blueprint, e
                        )
            if request.blueprint in self._blueprint_limits and not limits:
                limits.extend(self._blueprint_limits[request.blueprint])

        try:
            all_limits = []
            if self._storage_dead and self._fallback_limiter:
                if in_middleware and name in self.__marked_for_limiting:
                    # the decorated view performs its own check later on
                    pass
                else:
                    if self.__should_check_backend() and self._storage.check():
                        self.logger.info("Rate limit storage recovered")
                        self._storage_dead = False
                        self.__check_backend_count = 0
                    else:
                        all_limits = list(
                            itertools.chain(*self._in_memory_fallback)
                        )
            if not all_limits:
                route_limits = limits + dynamic_limits
                all_limits = list(itertools.chain(*self._application_limits))
                all_limits += route_limits
                # default limits apply when the route has no limits of its
                # own and is not a decorated view seen from the middleware,
                # or whenever the view was decorated the pre-1.0 way
                if (
                    (
                        not route_limits
                        and not (
                            in_middleware
                            and name in self.__marked_for_limiting
                        )
                    )
                    or implicit_decorator
                ):
                    all_limits += list(itertools.chain(*self._default_limits))
            self.__evaluate_limits(endpoint, all_limits)
        except Exception as e:  # no qa
            if isinstance(e, RateLimitExceeded):
                six.reraise(*sys.exc_info())
            if self._in_memory_fallback and not self._storage_dead:
                # Logger.warn is a deprecated alias (removed in Python 3.13)
                self.logger.warning(
                    "Rate limit storage unreachable - falling back to"
                    " in-memory storage"
                )
                self._storage_dead = True
                self.__check_request_limit(in_middleware)
            else:
                if self._swallow_errors:
                    self.logger.exception(
                        "Failed to rate limit. Swallowing error"
                    )
                else:
                    six.reraise(*sys.exc_info())
480
481
    def __limit_decorator(
        self,
        limit_value,
        key_func=None,
        shared=False,
        scope=None,
        per_method=False,
        methods=None,
        error_message=None,
        exempt_when=None,
    ):
        """
        Build the decorator behind :meth:`limit` and :meth:`shared_limit`.

        The returned callable registers the limit for a view function or a
        :class:`flask.Blueprint`. For a view function it returns a wrapper
        that runs the rate limit check before calling the view (when
        ``auto_check`` is on and the request was not checked already); for a
        blueprint it returns ``None``.

        ``scope`` is only honoured when ``shared`` is true. A callable
        ``limit_value`` is stored as a dynamic limit; a string is parsed
        right away and a parse failure is logged instead of raised.
        """
        _scope = scope if shared else None

        def _inner(obj):
            limit_key_func = key_func or self._key_func
            is_blueprint = isinstance(obj, Blueprint)
            if is_blueprint:
                name = obj.name
            else:
                name = "%s.%s" % (obj.__module__, obj.__name__)
            group_args = (
                limit_value, limit_key_func, _scope, per_method, methods,
                error_message, exempt_when
            )
            dynamic_limit, static_limits = None, []
            if callable(limit_value):
                dynamic_limit = LimitGroup(*group_args)
            else:
                try:
                    static_limits = list(LimitGroup(*group_args))
                except ValueError as e:
                    self.logger.error(
                        "failed to configure %s %s (%s)",
                        "blueprint" if is_blueprint else "view function",
                        name, e
                    )

            if is_blueprint:
                if dynamic_limit:
                    self._blueprint_dynamic_limits.setdefault(
                        name, []
                    ).append(dynamic_limit)
                else:
                    self._blueprint_limits.setdefault(
                        name, []
                    ).extend(static_limits)
                return None

            # remember the undecorated view so the middleware can recognise
            # views that were decorated in the pre-1.0 order
            self.__marked_for_limiting.setdefault(name, []).append(obj)
            if dynamic_limit:
                self._dynamic_route_limits.setdefault(
                    name, []
                ).append(dynamic_limit)
            else:
                self._route_limits.setdefault(
                    name, []
                ).extend(static_limits)

            @wraps(obj)
            def limited_view(*a, **k):
                if self._auto_check and not g.get("_rate_limiting_complete"):
                    self.__check_request_limit(False)
                    g._rate_limiting_complete = True
                return obj(*a, **k)

            return limited_view

        return _inner
547
548
    def limit(
549
        self,
550
        limit_value,
551
        key_func=None,
552
        per_method=False,
553
        methods=None,
554
        error_message=None,
555
        exempt_when=None,
556
    ):
557
        """
558
        decorator to be used for rate limiting individual routes or blueprints.
559
560
        :param limit_value: rate limit string or a callable that returns a string.
561
         :ref:`ratelimit-string` for more details.
562
        :param function key_func: function/lambda to extract the unique identifier for
563
         the rate limit. defaults to remote address of the request.
564
        :param bool per_method: whether the limit is sub categorized into the http
565
         method of the request.
566
        :param list methods: if specified, only the methods in this list will be rate
567
         limited (default: None).
568
        :param error_message: string (or callable that returns one) to override the
569
         error message used in the response.
570
        :param exempt_when:
571
        :return:
572
        """
573
        return self.__limit_decorator(
574
            limit_value,
575
            key_func,
576
            per_method=per_method,
577
            methods=methods,
578
            error_message=error_message,
579
            exempt_when=exempt_when,
580
        )
581
582
    def shared_limit(
583
        self,
584
        limit_value,
585
        scope,
586
        key_func=None,
587
        error_message=None,
588
        exempt_when=None,
589
    ):
590
        """
591
        decorator to be applied to multiple routes sharing the same rate limit.
592
593
        :param limit_value: rate limit string or a callable that returns a string.
594
         :ref:`ratelimit-string` for more details.
595
        :param scope: a string or callable that returns a string
596
         for defining the rate limiting scope.
597
        :param function key_func: function/lambda to extract the unique identifier for
598
         the rate limit. defaults to remote address of the request.
599
        :param error_message: string (or callable that returns one) to override the
600
         error message used in the response.
601
        :param exempt_when:
602
        """
603
        return self.__limit_decorator(
604
            limit_value,
605
            key_func,
606
            True,
607
            scope,
608
            error_message=error_message,
609
            exempt_when=exempt_when,
610
        )
611
612
    def exempt(self, obj):
        """
        decorator that excludes a view, or every view of a blueprint, from
        rate limiting.

        :param obj: a view function or a :class:`flask.Blueprint`
        :return: a wrapper around the view function; ``None`` for a blueprint
        """
        if isinstance(obj, Blueprint):
            self._blueprint_exempt.add(obj.name)
            return None

        route_name = "%s.%s" % (obj.__module__, obj.__name__)

        @wraps(obj)
        def passthrough(*a, **k):
            return obj(*a, **k)

        self._exempt_routes.add(route_name)
        return passthrough
627
628
    def request_filter(self, fn):
629
        """
630
        decorator to mark a function as a filter to be executed
631
        to check if the request is exempt from rate limiting.
632
        """
633
        self._request_filters.append(fn)
634
        return fn
635
636
    def raise_global_limits_warning(self):
637
        warnings.warn(
638
            "global_limits was a badly name configuration since it is actually a default limit and not a "
639
            " globally shared limit. Use default_limits if you want to provide a default or use application_limits "
640
            " if you intend to really have a global shared limit", UserWarning
641
        )
642