Passed
Push — master ( 105c1e...a48eb1 )
by Swen
02:08
created

SchemaEditor.create_model()   A

Complexity

Conditions 3

Size

Total Lines 10

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 6
CRAP Score 3

Importance

Changes 0
Metric Value
cc 3
c 0
b 0
f 0
dl 0
loc 10
rs 9.4285
ccs 6
cts 6
cp 1
crap 3
1 1
import importlib
2
3 1
from django.conf import settings
4 1
from django.core.exceptions import ImproperlyConfigured
5 1
from django.db.backends.postgresql.base import \
6
    DatabaseWrapper as Psycopg2DatabaseWrapper
7
8 1
from ..fields import LocalizedField
9
10
11 1
def _get_backend_base():
    """Gets the base class for the custom database back-end.

    This should be the Django PostgreSQL back-end. However,
    some people are already using a custom back-end from
    another package. We are nice people and expose an option
    that allows them to configure the back-end we base upon.

    As long as the specified base eventually also has
    the PostgreSQL back-end as a base, then everything should
    work as intended.

    Returns:
        The `DatabaseWrapper` class found in the `base`
        sub-module of the package named by the
        `LOCALIZED_FIELDS_DB_BACKEND_BASE` setting. When the
        setting is absent, `django.db.backends.postgresql`
        is used.

    Raises:
        ImproperlyConfigured:
            When the configured module does not define a
            `DatabaseWrapper` class, or when that class does
            not inherit from the PostgreSQL back-end.
    """
    base_class_name = getattr(
        settings,
        'LOCALIZED_FIELDS_DB_BACKEND_BASE',
        'django.db.backends.postgresql'
    )

    base_class_module = importlib.import_module(base_class_name + '.base')
    base_class = getattr(base_class_module, 'DatabaseWrapper', None)

    if not base_class:
        raise ImproperlyConfigured((
            '\'%s\' is not a valid database back-end.'
            ' The module does not define a DatabaseWrapper class.'
            ' Check the value of LOCALIZED_FIELDS_DB_BACKEND_BASE.'
        ) % base_class_name)

    # base_class is a class, not an instance, so inheritance has to be
    # verified with issubclass(). The isinstance(..., type) guard is there
    # because issubclass() raises TypeError when handed a non-class object.
    inherits_from_postgres = (
        isinstance(base_class, type) and
        issubclass(base_class, Psycopg2DatabaseWrapper)
    )

    if not inherits_from_postgres:
        raise ImproperlyConfigured((
            '\'%s\' is not a valid database back-end.'
            ' It does not inherit from the PostgreSQL back-end.'
            ' Check the value of LOCALIZED_FIELDS_DB_BACKEND_BASE.'
        ) % base_class_name)

    return base_class
47
48
49 1
def _get_schema_editor_base():
    """Gets the base class for the schema editor.

    The schema editor has to extend the one used by the
    configured base back-end, so it is read from that
    back-end's `SchemaEditorClass` attribute.

    Returns:
        The `SchemaEditorClass` of the class returned
        by `_get_backend_base()`.
    """
    backend_base = _get_backend_base()
    return backend_base.SchemaEditorClass
55
56
57 1
class SchemaEditor(_get_schema_editor_base()):
    """Custom schema editor for hstore indexes.

    This allows us to put UNIQUE constraints for
    localized fields."""

    # templates for the statements executed by this schema editor
    sql_hstore_unique_create = "CREATE UNIQUE INDEX {name} ON {table}{using} ({columns}){extra}"
    sql_hstore_unique_drop = "DROP INDEX IF EXISTS {name}"

    @staticmethod
    def _hstore_unique_name(model, field, keys):
        """Gets the name for a UNIQUE INDEX that applies
        to one or more keys in a hstore field.

        Arguments:
            model:
                The model the field is a part of.

            field:
                The hstore field to create a
                UNIQUE INDEX for.

            keys:
                A list or tuple with the names of the
                hstore keys to create the name for.

        Returns:
            The name for the UNIQUE index, composed of the
            table name, the field's column name and the
            keys joined by underscores.
        """
        postfix = '_'.join(keys)
        return '{table_name}_{field_name}_unique_{postfix}'.format(
            table_name=model._meta.db_table,
            field_name=field.column,
            postfix=postfix
        )

    def _drop_hstore_unique(self, model, field, keys):
        """Drops a UNIQUE constraint for the specified hstore keys.

        Arguments:
            model:
                The model the field is a part of.

            field:
                The hstore field the UNIQUE INDEX
                was created for.

            keys:
                A list or tuple with the names of the
                hstore keys the UNIQUE INDEX applies to.
        """

        name = self._hstore_unique_name(model, field, keys)
        sql = self.sql_hstore_unique_drop.format(name=name)
        self.execute(sql)

    def _create_hstore_unique(self, model, field, keys):
        """Creates a UNIQUE constraint for the specified hstore keys.

        Arguments:
            model:
                The model the field is a part of.

            field:
                The hstore field to create the
                UNIQUE INDEX for.

            keys:
                A list or tuple with the names of the
                hstore keys that have to be unique
                (together, when more than one is given).
        """

        name = self._hstore_unique_name(model, field, keys)

        # one index expression per key: (column->'key')
        columns = [
            '(%s->\'%s\')' % (field.column, key)
            for key in keys
        ]
        sql = self.sql_hstore_unique_create.format(
            name=name,
            table=model._meta.db_table,
            using='',
            columns=','.join(columns),
            extra=''
        )
        self.execute(sql)

    def _update_hstore_constraints(self, model, old_field, new_field):
        """Updates the UNIQUE constraints for the specified field.

        All constraints listed in the old field's `uniqueness`
        attribute are dropped, after which the ones listed in the
        new field's `uniqueness` attribute are created. Each entry
        in `uniqueness` is either a single key (a string) or a
        collection of keys that have to be unique together.

        Arguments:
            model:
                The model the field is a part of.

            old_field:
                The field as it was before the change.

            new_field:
                The field as it is after the change.
        """

        old_uniqueness = getattr(old_field, 'uniqueness', None)
        new_uniqueness = getattr(new_field, 'uniqueness', None)

        def _compose_keys(constraint):
            # a bare string is a constraint on a single key
            if isinstance(constraint, str):
                return [constraint]

            return constraint

        # drop any old uniqueness constraints; they were
        # named after the old field, so use that to find them
        if old_uniqueness:
            for keys in old_uniqueness:
                self._drop_hstore_unique(
                    model,
                    old_field,
                    _compose_keys(keys)
                )

        # (re-)create uniqueness constraints; these must be based
        # on the new field, since that holds the current column name
        if new_uniqueness:
            for keys in new_uniqueness:
                self._create_hstore_unique(
                    model,
                    new_field,
                    _compose_keys(keys)
                )

    def _alter_field(self, model, old_field, new_field, *args, **kwargs):
        """Ran when the configuration on a field changed.

        Lets the base schema editor apply the change first and
        then updates the hstore UNIQUE constraints if either the
        old or the new field is a `LocalizedField`."""

        super()._alter_field(
            model, old_field, new_field,
            *args, **kwargs
        )

        is_old_field_localized = isinstance(old_field, LocalizedField)
        is_new_field_localized = isinstance(new_field, LocalizedField)

        if is_old_field_localized or is_new_field_localized:
            self._update_hstore_constraints(model, old_field, new_field)

    def create_model(self, model):
        """Ran when a new model is created.

        Lets the base schema editor create the model and then
        sets up the hstore UNIQUE constraints for every
        `LocalizedField` declared on the model."""

        super().create_model(model)

        for field in model._meta.local_fields:
            if not isinstance(field, LocalizedField):
                continue

            self._update_hstore_constraints(model, field, field)
174
175
176 1
class DatabaseWrapper(_get_backend_base()):
    """Database back-end built on top of the configured base back-end.

    Replaces the schema editor with our custom one and
    enables the `hstore` extension when the database
    is being prepared."""

    # schema editor that knows about hstore UNIQUE indexes
    SchemaEditorClass = SchemaEditor

    def prepare_database(self):
        """Ran to prepare the configured database.

        After the base back-end has done its own preparation,
        the `hstore` extension is created if it does not
        exist yet."""

        super().prepare_database()

        enable_hstore_sql = 'CREATE EXTENSION IF NOT EXISTS hstore'
        with self.cursor() as db_cursor:
            db_cursor.execute(enable_hstore_sql)
194