Completed
Push — master ( 750add...37937f )
by Ali-Akber
01:05
created

  F

Complexity

Total Complexity 99

Size/Duplication

Total Lines 480
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
c 0
b 0
f 0
dl 0
loc 480
wmc 99
rs 1.5789

17 Methods

Rating   Name   Duplication   Size   Complexity  
F Limiter.__check_request_limit() 0 117 42
A Limiter.limit() 0 20 1
A Limiter.__should_check_backend() 0 8 3
F Limiter.__limit_decorator() 0 51 12
A Limiter.limiter() 0 6 3
A Limiter.raise_global_limits_warning() 0 6 1
F Limiter.init_app() 0 74 15
A Limiter.request_filter() 0 7 1
A BlackHoleHandler.emit() 0 2 1
F Limiter.__init__() 0 83 10
A Limiter.reset() 0 9 2
A Limiter.shared_limit() 0 17 1
F Limiter._inner() 0 41 10
A Limiter.exempt() 0 13 3
A Limiter.__inner() 0 3 1
B Limiter.__inject_headers() 0 22 4
A Limiter.check() 0 7 1

How to fix   Complexity   

Complex Class

Complex classes like Limiter often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
"""
2
the flask extension
3
"""
4
import warnings
5
from functools import wraps
6
import logging
7
8
from flask import request, current_app, g, Blueprint
9
from werkzeug.http import http_date
10
11
from limits.errors import ConfigurationError
12
from limits.storage import storage_from_string, MemoryStorage
13
from limits.strategies import STRATEGIES
14
from limits.util import parse_many
15
import six
16
import sys
17
import time
18
from .errors import RateLimitExceeded
19
from .util import get_ipaddr
20
21
class C:
    """
    Names of the Flask ``app.config`` keys consulted by :class:`Limiter`.
    """
    # feature toggles
    ENABLED = "RATELIMIT_ENABLED"
    SWALLOW_ERRORS = "RATELIMIT_SWALLOW_ERRORS"

    # storage backend and rate limiting strategy
    STORAGE_URL = "RATELIMIT_STORAGE_URL"
    STORAGE_OPTIONS = "RATELIMIT_STORAGE_OPTIONS"
    STRATEGY = "RATELIMIT_STRATEGY"

    # limit definitions (GLOBAL_LIMITS is the deprecated name for DEFAULT_LIMITS)
    GLOBAL_LIMITS = "RATELIMIT_GLOBAL"
    DEFAULT_LIMITS = "RATELIMIT_DEFAULT"
    APPLICATION_LIMITS = "RATELIMIT_APPLICATION"
    IN_MEMORY_FALLBACK = "RATELIMIT_IN_MEMORY_FALLBACK"

    # response header configuration
    HEADERS_ENABLED = "RATELIMIT_HEADERS_ENABLED"
    HEADER_LIMIT = "RATELIMIT_HEADER_LIMIT"
    HEADER_REMAINING = "RATELIMIT_HEADER_REMAINING"
    HEADER_RESET = "RATELIMIT_HEADER_RESET"
    HEADER_RETRY_AFTER = "RATELIMIT_HEADER_RETRY_AFTER"
    HEADER_RETRY_AFTER_VALUE = "RATELIMIT_HEADER_RETRY_AFTER_VALUE"
37
38
class HEADERS:
    """
    Symbolic identifiers for the rate limit response headers. The concrete
    header names are looked up from configuration by :class:`Limiter`.
    """
    RESET, REMAINING, LIMIT, RETRY_AFTER = 1, 2, 3, 4
43
44
# Upper bound for the exponential back-off counter that Limiter uses while
# re-probing a storage backend marked as dead; past it the counter restarts at 0.
MAX_BACKEND_CHECKS = 5
45
46
class ExtLimit(object):
    """
    Bundle a single rate limit together with the context needed to apply it.

    ``limit`` and ``scope`` may each be a plain value or a callable; callables
    are resolved lazily every time the matching property is read.
    """

    def __init__(self, limit, key_func, scope, per_method, methods, error_message,
                 exempt_when):
        self._limit = limit
        self._scope = scope
        self.key_func = key_func
        self.per_method = per_method
        self.error_message = error_message
        self.exempt_when = exempt_when
        # HTTP method names are normalised to lower case; falsy values
        # (None, an empty list) are stored untouched.
        if methods:
            lowered = [method.lower() for method in methods]
            self.methods = lowered or methods
        else:
            self.methods = methods

    @property
    def limit(self):
        """The limit value, invoking the stored callable first if there is one."""
        if callable(self._limit):
            return self._limit()
        return self._limit

    @property
    def scope(self):
        """The scope, invoking the stored callable with the request endpoint if needed."""
        if callable(self._scope):
            return self._scope(request.endpoint)
        return self._scope

    @property
    def is_exempt(self):
        """Check if the limit is exempt (falsy when no ``exempt_when`` was given)."""
        if not self.exempt_when:
            return self.exempt_when
        return self.exempt_when()
72
73
class Limiter(object):
    """
    Flask extension that applies rate limits to incoming requests.

    :param app: :class:`flask.Flask` instance to initialize the extension
     with.
    :param function key_func: a callable that returns the domain to rate limit by.
     When omitted, ``get_ipaddr`` is used and a warning is emitted.
    :param list global_limits: deprecated alias for ``default_limits``; passing
     it emits a warning.
    :param list default_limits: a variable list of strings denoting global
     limits to apply to all routes. :ref:`ratelimit-string` for more details.
    :param list application_limits: a variable list of strings for limits that
     are applied to the entire application (i.e a shared limit for all routes)
    :param bool headers_enabled: whether ``X-RateLimit`` response headers are written.
    :param str strategy: the strategy to use. refer to :ref:`ratelimit-strategy`
    :param str storage_uri: the storage location. refer to :ref:`ratelimit-conf`
    :param dict storage_options: kwargs to pass to the storage implementation upon
      instantiation. The dict is copied; the caller's object is never mutated.
    :param bool auto_check: whether to automatically check the rate limit in the before_request
     chain of the application. default ``True``
    :param bool swallow_errors: whether to swallow errors when hitting a rate limit.
     An exception will still be logged. default ``False``
    :param list in_memory_fallback: a variable list of strings denoting fallback
     limits to apply when the storage is down.
    :param str retry_after: ``'http-date'`` to send the ``Retry-After`` header as
     an HTTP date; any other value sends the number of seconds until reset.
    """

    def __init__(self, app=None, key_func=None, global_limits=None,
                 default_limits=None, application_limits=None,
                 headers_enabled=False, strategy=None, storage_uri=None,
                 storage_options=None, auto_check=True, swallow_errors=False,
                 in_memory_fallback=None, retry_after=None):
        self.app = app
        self.logger = logging.getLogger("flask-limiter")

        self.enabled = True
        self._exempt_routes = set()
        self._request_filters = []
        self._headers_enabled = headers_enabled
        self._header_mapping = {}
        self._retry_after = retry_after
        self._strategy = strategy
        self._storage_uri = storage_uri
        # Private copy: init_app() updates these options in place, which must
        # not leak into the caller's dict or into other Limiter instances.
        self._storage_options = dict(storage_options or {})
        self._auto_check = auto_check
        self._swallow_errors = swallow_errors
        if not key_func:
            warnings.warn(
                "Use of the default `get_ipaddr` function is discouraged."
                " Please refer to https://flask-limiter.readthedocs.org/#rate-limit-domain"
                " for the recommended configuration",
                UserWarning
            )
        if global_limits:
            self.raise_global_limits_warning()

        self._key_func = key_func or get_ipaddr
        # De-duplicate default limit strings while keeping declaration order,
        # so limits are always evaluated in a deterministic order.
        default_strings = []
        for limit_string in list(global_limits or []) + list(default_limits or []):
            if limit_string not in default_strings:
                default_strings.append(limit_string)
        self._default_limits = self._build_limits(default_strings)
        self._application_limits = self._build_limits(
            application_limits or [], "global"
        )
        self._in_memory_fallback = self._build_limits(in_memory_fallback or [])
        self._route_limits = {}
        self._dynamic_route_limits = {}
        self._blueprint_limits = {}
        self._blueprint_dynamic_limits = {}
        self._blueprint_exempt = set()
        self._storage = self._limiter = None
        self._storage_dead = False
        self._fallback_storage = None
        self._fallback_limiter = None
        self.__check_backend_count = 0
        self.__last_check_backend = time.time()

        # Attach one no-op handler to the shared "flask-limiter" logger; the
        # check prevents stacking an extra handler for every Limiter instance.
        if not any(isinstance(handler, logging.NullHandler)
                   for handler in self.logger.handlers):
            self.logger.addHandler(logging.NullHandler())
        if app:
            self.init_app(app)

    def _build_limits(self, limit_strings, scope=None):
        """
        Parse rate limit strings into :class:`ExtLimit` objects.

        Each string is expanded with ``parse_many``, so a single string may
        yield several limits.

        :param limit_strings: iterable of rate limit strings.
        :param scope: scope given to every generated limit (``None`` means the
         limit is scoped to the endpoint being requested).
        :return: list of :class:`ExtLimit` using the default key function.
        """
        return [
            ExtLimit(limit, self._key_func, scope, False, None, None, None)
            for limit_string in limit_strings
            for limit in parse_many(limit_string)
        ]

    def _expand_dynamic_limits(self, dynamic_limits, kind, name):
        """
        Resolve callable (dynamic) limits into concrete :class:`ExtLimit` objects.

        Limits whose resolved value cannot be parsed are logged and skipped.

        :param dynamic_limits: iterable of :class:`ExtLimit` with callable limits.
        :param str kind: ``"view function"`` or ``"blueprint"``, used for logging.
        :param str name: name of the view function / blueprint, used for logging.
        :return: list of :class:`ExtLimit` with concrete limit values.
        """
        expanded = []
        for lim in dynamic_limits:
            try:
                expanded.extend(
                    ExtLimit(
                        limit, lim.key_func, lim.scope, lim.per_method,
                        lim.methods, lim.error_message, lim.exempt_when
                    ) for limit in parse_many(lim.limit)
                )
            except ValueError as e:
                self.logger.error(
                    "failed to load ratelimit for %s %s (%s)", kind, name, e
                )
        return expanded

    def init_app(self, app):
        """
        Configure the extension for *app* and register its request hooks.

        :param app: :class:`flask.Flask` instance to rate limit.
        :raises ConfigurationError: if the configured strategy is unknown.
        """
        self.enabled = app.config.setdefault(C.ENABLED, True)
        self._swallow_errors = app.config.setdefault(
            C.SWALLOW_ERRORS, self._swallow_errors
        )
        self._headers_enabled = (
            self._headers_enabled
            or app.config.setdefault(C.HEADERS_ENABLED, False)
        )
        self._storage_options.update(
            app.config.get(C.STORAGE_OPTIONS, {})
        )
        self._storage = storage_from_string(
            self._storage_uri
            or app.config.setdefault(C.STORAGE_URL, 'memory://'),
            **self._storage_options
        )
        strategy = (
            self._strategy
            or app.config.setdefault(C.STRATEGY, 'fixed-window')
        )
        if strategy not in STRATEGIES:
            raise ConfigurationError("Invalid rate limiting strategy %s" % strategy)
        self._limiter = STRATEGIES[strategy](self._storage)
        # Explicitly configured header names win; otherwise fall back to the
        # app config, seeding it with the default header name.
        for header, config_key, default_name in (
            (HEADERS.RESET, C.HEADER_RESET, "X-RateLimit-Reset"),
            (HEADERS.REMAINING, C.HEADER_REMAINING, "X-RateLimit-Remaining"),
            (HEADERS.LIMIT, C.HEADER_LIMIT, "X-RateLimit-Limit"),
            (HEADERS.RETRY_AFTER, C.HEADER_RETRY_AFTER, "Retry-After"),
        ):
            self._header_mapping[header] = (
                self._header_mapping.get(header, None)
                or app.config.setdefault(config_key, default_name)
            )
        self._retry_after = (
            self._retry_after
            or app.config.get(C.HEADER_RETRY_AFTER_VALUE)
        )

        # Limits passed to the constructor take precedence over app config.
        app_limits = app.config.get(C.APPLICATION_LIMITS, None)
        if not self._application_limits and app_limits:
            self._application_limits = self._build_limits([app_limits], "global")

        if app.config.get(C.GLOBAL_LIMITS, None):
            self.raise_global_limits_warning()
        conf_limits = app.config.get(
            C.GLOBAL_LIMITS, app.config.get(C.DEFAULT_LIMITS, None)
        )
        if not self._default_limits and conf_limits:
            self._default_limits = self._build_limits([conf_limits])
        fallback_limits = app.config.get(C.IN_MEMORY_FALLBACK, None)
        if not self._in_memory_fallback and fallback_limits:
            self._in_memory_fallback = self._build_limits([fallback_limits])
        if self._auto_check:
            app.before_request(self.__check_request_limit)
        app.after_request(self.__inject_headers)

        if self._in_memory_fallback:
            self._fallback_storage = MemoryStorage()
            self._fallback_limiter = STRATEGIES[strategy](self._fallback_storage)

        # purely for backward compatibility as stated in flask documentation
        if not hasattr(app, 'extensions'):
            app.extensions = {}  # pragma: no cover
        app.extensions['limiter'] = self

    def __should_check_backend(self):
        """
        Decide whether a dead storage backend should be probed again.

        Probes are spaced ``2 ** n`` seconds apart, where ``n`` counts the
        probes made so far and wraps back to 0 once it exceeds
        ``MAX_BACKEND_CHECKS``.
        """
        if self.__check_backend_count > MAX_BACKEND_CHECKS:
            self.__check_backend_count = 0
        if time.time() - self.__last_check_backend > pow(2, self.__check_backend_count):
            self.__last_check_backend = time.time()
            self.__check_backend_count += 1
            return True
        return False

    def check(self):
        """
        check the limits for the current request

        :raises: RateLimitExceeded
        """
        self.__check_request_limit()

    def reset(self):
        """
        resets the storage if it supports being reset
        """
        try:
            self._storage.reset()
            self.logger.info("Storage has been reset and all limits cleared")
        except NotImplementedError:
            self.logger.warning("This storage type does not support being reset")

    @property
    def limiter(self):
        """The active strategy: the in-memory fallback while storage is dead."""
        if self._storage_dead and self._in_memory_fallback:
            return self._fallback_limiter
        return self._limiter

    def __inject_headers(self, response):
        """
        ``after_request`` hook that adds rate limit headers to *response*.

        Headers are only written when the extension and headers are enabled
        and the request was checked against at least one limit.
        """
        current_limit = getattr(g, 'view_rate_limit', None)
        if not (self.enabled and self._headers_enabled and current_limit):
            return response
        try:
            window_stats = self.limiter.get_window_stats(*current_limit)
        except Exception:  # noqa - storage failures honour swallow_errors
            if not self._swallow_errors:
                raise
            self.logger.exception(
                "Failed to update rate limit headers. Swallowing error"
            )
            return response
        reset_at, remaining = window_stats[0], window_stats[1]
        response.headers.add(
            self._header_mapping[HEADERS.LIMIT],
            str(current_limit[0].amount)
        )
        response.headers.add(self._header_mapping[HEADERS.REMAINING], remaining)
        response.headers.add(self._header_mapping[HEADERS.RESET], reset_at)
        if self._retry_after == 'http-date':
            retry_after = http_date(reset_at)
        else:
            retry_after = int(reset_at - time.time())
        response.headers.add(self._header_mapping[HEADERS.RETRY_AFTER], retry_after)
        return response

    def __check_request_limit(self):
        """
        Evaluate the applicable limits for the current request.

        Route limits take precedence over blueprint limits, which take
        precedence over default limits; application limits always apply.
        Falls back to in-memory storage when the backend is unreachable and a
        fallback is configured.

        :raises RateLimitExceeded: when any applicable limit is exceeded.
        """
        endpoint = request.endpoint or ""
        view_func = current_app.view_functions.get(endpoint, None)
        name = (
            "%s.%s" % (view_func.__module__, view_func.__name__)
            if view_func else ""
        )
        if (not request.endpoint
                or not self.enabled
                or view_func == current_app.send_static_file
                or name in self._exempt_routes
                or request.blueprint in self._blueprint_exempt
                or any(fn() for fn in self._request_filters)):
            return
        # Copy so that appending blueprint limits never touches the registry.
        limits = list(self._route_limits.get(name, []))
        dynamic_limits = self._expand_dynamic_limits(
            self._dynamic_route_limits.get(name, []), "view function", name
        )
        if request.blueprint:
            if (request.blueprint in self._blueprint_dynamic_limits
                    and not dynamic_limits):
                dynamic_limits = self._expand_dynamic_limits(
                    self._blueprint_dynamic_limits[request.blueprint],
                    "blueprint", request.blueprint
                )
            if request.blueprint in self._blueprint_limits and not limits:
                limits.extend(self._blueprint_limits[request.blueprint])

        failed_limit = None
        limit_for_header = None
        try:
            all_limits = []
            if self._storage_dead and self._fallback_limiter:
                if self.__should_check_backend() and self._storage.check():
                    self.logger.info("Rate limit storage recovered")
                    self._storage_dead = False
                    self.__check_backend_count = 0
                else:
                    all_limits = self._in_memory_fallback
            if not all_limits:
                all_limits = self._application_limits + (
                    limits + dynamic_limits or self._default_limits
                )
            for lim in all_limits:
                limit_scope = lim.scope or endpoint
                # Limits that do not apply to this request are skipped; the
                # remaining limits must still be evaluated.
                if lim.is_exempt:
                    continue
                if lim.methods is not None and request.method.lower() not in lim.methods:
                    continue
                if lim.per_method:
                    limit_scope += ":%s" % request.method
                key = lim.key_func()
                # Report the strictest limit in headers unless one is exceeded.
                if not limit_for_header or lim.limit < limit_for_header[0]:
                    limit_for_header = (lim.limit, key, limit_scope)
                if not self.limiter.hit(lim.limit, key, limit_scope):
                    self.logger.warning(
                        "ratelimit %s (%s) exceeded at endpoint: %s",
                        lim.limit, key, limit_scope
                    )
                    failed_limit = lim
                    limit_for_header = (lim.limit, key, limit_scope)
                    break

            g.view_rate_limit = limit_for_header

            if failed_limit:
                if failed_limit.error_message:
                    exc_description = (
                        failed_limit.error_message()
                        if callable(failed_limit.error_message)
                        else failed_limit.error_message
                    )
                else:
                    exc_description = six.text_type(failed_limit.limit)
                raise RateLimitExceeded(exc_description)
        except RateLimitExceeded:
            raise
        except Exception:  # noqa - any storage failure lands here
            if self._in_memory_fallback and not self._storage_dead:
                self.logger.warning(
                    "Rate limit storage unreachable - falling back to"
                    " in-memory storage"
                )
                self._storage_dead = True
                self.__check_request_limit()
            elif self._swallow_errors:
                self.logger.exception("Failed to rate limit. Swallowing error")
            else:
                raise

    def __limit_decorator(self, limit_value,
                          key_func=None, shared=False,
                          scope=None,
                          per_method=False,
                          methods=None,
                          error_message=None,
                          exempt_when=None):
        """
        Build a decorator that registers limits for a view function or blueprint.

        ``scope`` is only honoured for shared limits. Callable limit values are
        registered as dynamic limits and resolved on every request.
        """
        _scope = scope if shared else None

        def _inner(obj):
            func = key_func or self._key_func
            is_route = not isinstance(obj, Blueprint)
            name = "%s.%s" % (obj.__module__, obj.__name__) if is_route else obj.name
            dynamic_limit, static_limits = None, []
            if callable(limit_value):
                dynamic_limit = ExtLimit(limit_value, func, _scope, per_method,
                                         methods, error_message, exempt_when)
            else:
                try:
                    static_limits = [ExtLimit(
                        limit, func, _scope, per_method,
                        methods, error_message, exempt_when
                    ) for limit in parse_many(limit_value)]
                except ValueError as e:
                    self.logger.error(
                        "failed to configure %s %s (%s)",
                        "view function" if is_route else "blueprint", name, e
                    )
            if not is_route:
                if dynamic_limit:
                    self._blueprint_dynamic_limits.setdefault(name, []).append(
                        dynamic_limit
                    )
                else:
                    self._blueprint_limits.setdefault(name, []).extend(
                        static_limits
                    )
                return obj

            @wraps(obj)
            def __inner(*a, **k):
                return obj(*a, **k)
            if dynamic_limit:
                self._dynamic_route_limits.setdefault(name, []).append(
                    dynamic_limit
                )
            else:
                self._route_limits.setdefault(name, []).extend(
                    static_limits
                )
            return __inner
        return _inner

    def limit(self, limit_value, key_func=None, per_method=False,
              methods=None, error_message=None, exempt_when=None):
        """
        decorator to be used for rate limiting individual routes or blueprints.

        :param limit_value: rate limit string or a callable that returns a string.
         :ref:`ratelimit-string` for more details.
        :param function key_func: function/lambda to extract the unique identifier for
         the rate limit. defaults to remote address of the request.
        :param bool per_method: whether the limit is sub categorized into the http
         method of the request.
        :param list methods: if specified, only the methods in this list will be rate
         limited (default: None).
        :param error_message: string (or callable that returns one) to override the
         error message used in the response.
        :param function exempt_when: callable; when it returns a truthy value the
         limit is skipped for the current request.
        :return: the decorator.
        """
        return self.__limit_decorator(limit_value, key_func, per_method=per_method,
                                      methods=methods, error_message=error_message,
                                      exempt_when=exempt_when)

    def shared_limit(self, limit_value, scope, key_func=None,
                     error_message=None, exempt_when=None):
        """
        decorator to be applied to multiple routes sharing the same rate limit.

        :param limit_value: rate limit string or a callable that returns a string.
         :ref:`ratelimit-string` for more details.
        :param scope: a string or callable that returns a string
         for defining the rate limiting scope.
        :param function key_func: function/lambda to extract the unique identifier for
         the rate limit. defaults to remote address of the request.
        :param error_message: string (or callable that returns one) to override the
         error message used in the response.
        :param function exempt_when: callable; when it returns a truthy value the
         limit is skipped for the current request.
        """
        return self.__limit_decorator(
            limit_value, key_func, True, scope, error_message=error_message,
            exempt_when=exempt_when
        )

    def exempt(self, obj):
        """
        decorator to mark a view or all views in a blueprint as exempt from rate limits.

        :return: a wrapper of the view function, or the blueprint itself.
        """
        if isinstance(obj, Blueprint):
            self._blueprint_exempt.add(obj.name)
            return obj
        name = "%s.%s" % (obj.__module__, obj.__name__)

        @wraps(obj)
        def __inner(*a, **k):
            return obj(*a, **k)
        self._exempt_routes.add(name)
        return __inner

    def request_filter(self, fn):
        """
        decorator to mark a function as a filter to be executed
        to check if the request is exempt from rate limiting.
        """
        self._request_filters.append(fn)
        return fn

    def raise_global_limits_warning(self):
        """Warn that ``global_limits`` is a deprecated alias for default limits."""
        warnings.warn(
            "global_limits was a badly named configuration since it is actually a default limit and not a"
            " globally shared limit. Use default_limits if you want to provide a default or use application_limits"
            " if you intend to really have a global shared limit",
            UserWarning
        )
554