Completed
Push — master ( 99002b...8124bf )
by Bertrand
49s
created

cachalot.inner()   A

Complexity

Conditions 1

Size

Total Lines 4

Duplication

Lines 0
Ratio 0 %
| Metric | Value |
|--------|-------|
| cc     | 1     |
| dl     | 0     |
| loc    | 4     |
| rs     | 10    |
1
# coding: utf-8
2
3
from __future__ import unicode_literals
try:
    # ``collections.Iterable`` was removed in Python 3.10.
    from collections.abc import Iterable
except ImportError:  # Python 2 fallback
    from collections import Iterable
from functools import wraps
from time import time

from django.db.backends.utils import CursorWrapper
from django.db.models.query import EmptyResultSet
from django.db.models.signals import post_migrate
from django.db.models.sql.compiler import (
    SQLCompiler, SQLInsertCompiler, SQLUpdateCompiler, SQLDeleteCompiler)
from django.db.transaction import Atomic, get_connection

from .api import invalidate
from .cache import cachalot_caches
from .settings import cachalot_settings
from .utils import (
    _get_query_cache_key, _get_table_cache_keys, _get_tables_from_sql,
    _invalidate_table, UncachableQuery, TUPLE_OR_LIST)
21
22
23
# Compiler classes whose ``execute_sql`` modifies data. The patched read path
# skips caching for instances of these, and each of them gets a wrapper that
# invalidates its table before executing.
WRITE_COMPILERS = (
    SQLInsertCompiler,
    SQLUpdateCompiler,
    SQLDeleteCompiler,
)
24
25
26
def _unset_raw_connection(original):
    """Wrap a compiler's ``execute_sql`` so its queries are not seen as raw SQL.

    While ``original`` runs, ``compiler.connection.raw`` is ``False``. The
    patched cursor ``execute`` reads this flag (treating an absent flag as
    ``True``) to decide whether a statement is raw SQL that needs table
    invalidation.

    The previous value of the flag is restored in a ``finally`` clause:

    * an exception raised by ``original`` can no longer leave the connection
      stuck in "not raw" mode, which would silently disable raw-SQL
      invalidation for every later statement on that connection;
    * nested wrapped calls on the same connection no longer flip the flag
      back to ``True`` while the outer call is still running.

    :param original: the ``execute_sql`` callable to wrap.
    :returns: the wrapping function, taking ``(compiler, *args, **kwargs)``.
    """
    @wraps(original)
    def inner(compiler, *args, **kwargs):
        connection = compiler.connection
        # Same default as the cursor patch: an unset flag means "raw".
        previous = getattr(connection, 'raw', True)
        connection.raw = False
        try:
            return original(compiler, *args, **kwargs)
        finally:
            connection.raw = previous
    return inner
33
34
35
def _get_result_or_execute_query(execute_query_func, cache,
                                 cache_key, table_cache_keys):
    """Serve a query result from ``cache`` when it is still valid, else run it.

    The cache holds two kinds of entries, all written with timeout ``None``:

    * one timestamp per table cache key;
    * a ``(timestamp, result)`` pair under ``cache_key``.

    A cached result is returned only if every table key was already present
    and the result's timestamp is strictly newer than all table timestamps.
    Otherwise, missing table keys are first stamped with the current time.
    Then ``execute_query_func()`` runs and its result is stored under
    ``cache_key``.

    :param execute_query_func: zero-argument callable executing the query.
    :param cache: object exposing ``get_many``, ``set_many`` and ``set``.
    :param cache_key: key under which the query result is cached.
    :param table_cache_keys: list of keys, one per table the query reads.
    :returns: the cached or freshly computed result. An iterable result
        whose class is not in ``TUPLE_OR_LIST`` is turned into a ``list``
        before being cached.
    """
    found = cache.get_many(table_cache_keys + [cache_key])

    unseen_tables = {key for key in table_cache_keys if key not in found}
    if unseen_tables:
        # Tables with no timestamp yet: stamp them, and do not trust any
        # cached result for this query.
        stamp = time()
        cache.set_many(dict.fromkeys(unseen_tables, stamp), None)
    elif cache_key in found:
        stored_at, stored_result = found.pop(cache_key)
        # After the pop, ``found`` only holds table timestamps.
        if found and stored_at > max(found.values()):
            return stored_result

    fresh_result = execute_query_func()
    if (isinstance(fresh_result, Iterable)
            and fresh_result.__class__ not in TUPLE_OR_LIST):
        # Normalise other iterables (e.g. one-shot iterators) to a list
        # so the cached value can be reused.
        fresh_result = list(fresh_result)

    cache.set(cache_key, (time(), fresh_result), None)
    return fresh_result
58
59
60
def _patch_compiler(original):
    """Return a caching replacement for a read compiler's ``execute_sql``.

    The query runs directly, with no caching, in three cases:

    * ``CACHALOT_ENABLED`` is falsy;
    * the compiler is one of ``WRITE_COMPILERS``;
    * the query or table cache keys cannot be computed (``EmptyResultSet``
      or ``UncachableQuery``).

    Otherwise ``_get_result_or_execute_query`` decides between a cached and
    a fresh result, using the cache for ``compiler.using``. The wrapper also
    goes through ``_unset_raw_connection`` and carries ``original``'s
    metadata.
    """
    def inner(compiler, *args, **kwargs):
        def run_query():
            return original(compiler, *args, **kwargs)

        caching_allowed = (cachalot_settings.CACHALOT_ENABLED
                           and not isinstance(compiler, WRITE_COMPILERS))
        if not caching_allowed:
            return run_query()

        try:
            query_key = _get_query_cache_key(compiler)
            table_keys = _get_table_cache_keys(compiler)
        except (EmptyResultSet, UncachableQuery):
            return run_query()

        return _get_result_or_execute_query(
            run_query,
            cachalot_caches.get_cache(db_alias=compiler.using),
            query_key, table_keys)

    return wraps(original)(_unset_raw_connection(inner))
81
82
83
def _patch_write_compiler(original):
    """Return a replacement for a write compiler's ``execute_sql``.

    Before delegating to ``original``, the wrapper invalidates the table of
    the query's model (``query.get_meta().db_table``). The invalidation goes
    through the cache for the compiler's database alias. The wrapper also
    goes through ``_unset_raw_connection`` and carries ``original``'s
    metadata.
    """
    def inner(write_compiler, *args, **kwargs):
        alias = write_compiler.using
        written_table = write_compiler.query.get_meta().db_table
        cache = cachalot_caches.get_cache(db_alias=alias)
        _invalidate_table(cache, alias, written_table)
        return original(write_compiler, *args, **kwargs)

    return wraps(original)(_unset_raw_connection(inner))
94
95
96
def _patch_orm():
    """Install the cachalot wrappers on Django's SQL compilers.

    ``SQLCompiler.execute_sql`` gets the caching wrapper first. Then each
    class in ``WRITE_COMPILERS`` gets an invalidating wrapper around
    whatever ``execute_sql`` it resolves to at that moment.
    """
    SQLCompiler.execute_sql = _patch_compiler(SQLCompiler.execute_sql)
    for write_cls in WRITE_COMPILERS:
        current = write_cls.execute_sql
        write_cls.execute_sql = _patch_write_compiler(current)
100
101
102
def _patch_cursor():
    """Make raw SQL writes invalidate the tables they touch.

    ``CursorWrapper.execute`` and ``executemany`` are wrapped. After the
    statement has run, if the connection is in raw mode (no ``raw``
    attribute counts as raw) and ``CACHALOT_INVALIDATE_RAW`` is set, the
    SQL is lowercased. If it contains ``update``, ``insert`` or ``delete``,
    the tables found in it are invalidated on the cursor's database alias.
    """
    def _patch_cursor_execute(original):
        @wraps(original)
        def inner(cursor, sql, *args, **kwargs):
            result = original(cursor, sql, *args, **kwargs)
            db = cursor.db
            if not getattr(db, 'raw', True):
                # Issued by a patched ORM compiler, which handles caching.
                return result
            if not cachalot_settings.CACHALOT_INVALIDATE_RAW:
                return result

            lowered = sql.lower()
            # Plain substring test: may over-match (e.g. a column named
            # "updated_at"), which only causes extra invalidation.
            looks_like_write = any(
                keyword in lowered
                for keyword in ('update', 'insert', 'delete'))
            if looks_like_write:
                touched = _get_tables_from_sql(db, lowered)
                invalidate(*touched, db_alias=db.alias)
            return result

        return inner

    CursorWrapper.execute = _patch_cursor_execute(CursorWrapper.execute)
    CursorWrapper.executemany = _patch_cursor_execute(
        CursorWrapper.executemany)
119
120
121
def _patch_atomic():
    """Keep cachalot's atomic bookkeeping in step with Django's ``Atomic``.

    ``Atomic.__enter__`` is wrapped to call ``cachalot_caches.enter_atomic``
    first. ``Atomic.__exit__`` is wrapped to call
    ``cachalot_caches.exit_atomic`` afterwards. The second argument of
    ``exit_atomic`` is ``True`` only if the block exited without an
    exception and the connection did not need a rollback.

    Every ``enter_atomic`` is now paired with exactly one ``exit_atomic``,
    even when Django's own ``__enter__`` or ``__exit__`` raises (for
    example, a failing commit). In those cases ``exit_atomic`` is called
    with ``False`` and the exception propagates. The wrapped methods'
    return values are passed through unchanged.
    """
    def patch_enter(original):
        @wraps(original)
        def inner(self):
            cachalot_caches.enter_atomic(self.using)
            try:
                return original(self)
            except BaseException:
                # A ``with`` statement never calls ``__exit__`` when
                # ``__enter__`` raises, so balance the enter_atomic above
                # here, treating the block as not committed.
                cachalot_caches.exit_atomic(self.using, False)
                raise

        return inner

    def patch_exit(original):
        @wraps(original)
        def inner(self, exc_type, exc_value, traceback):
            # Captured before ``original`` finishes the transaction, since
            # that may change the connection's state.
            needs_rollback = get_connection(self.using).needs_rollback
            try:
                out = original(self, exc_type, exc_value, traceback)
            except BaseException:
                # Django's exit itself failed (e.g. commit error): keep the
                # bookkeeping balanced and treat the block as not committed.
                cachalot_caches.exit_atomic(self.using, False)
                raise
            cachalot_caches.exit_atomic(
                self.using, exc_type is None and not needs_rollback)
            return out

        return inner

    Atomic.__enter__ = patch_enter(Atomic.__enter__)
    Atomic.__exit__ = patch_exit(Atomic.__exit__)
142
143
144
def _invalidate_on_migration(sender, **kwargs):
    """Signal receiver: invalidate the models of ``sender`` after migrating.

    ``sender`` is presumably the migrated app's config (it must provide
    ``get_models()``). ``kwargs['using']`` is the database alias that was
    migrated; a missing ``using`` raises ``KeyError``.
    """
    migrated_models = sender.get_models()
    invalidate(*migrated_models, db_alias=kwargs['using'])
146
147
148
def patch():
    """Install every cachalot hook.

    Connects ``_invalidate_on_migration`` to ``post_migrate``, then patches
    the cursor, ``Atomic`` and the ORM compilers, in that order. Nothing
    here guards against repeated calls: each call wraps the currently
    installed methods again.
    """
    post_migrate.connect(_invalidate_on_migration)

    for install in (_patch_cursor, _patch_atomic, _patch_orm):
        install()
154