|
1
|
1 |
|
from django.core.exceptions import SuspiciousOperation |
|
|
|
|
|
|
2
|
1 |
|
from django.db.models.sql.compiler import SQLInsertCompiler, SQLUpdateCompiler |
|
|
|
|
|
|
3
|
|
|
|
|
4
|
|
|
|
|
5
|
1 |
|
class PostgresReturningUpdateCompiler(SQLUpdateCompiler):
    """Compiler for SQL UPDATE statements that return
    the primary keys of the affected rows."""

    def execute_sql(self, _result_type):
        """Executes the UPDATE statement with a RETURNING clause appended.

        Arguments:
            _result_type:
                Unused; accepted so the signature matches the
                `execute_sql` method of the base compiler.

        Returns:
            The rows fetched with ``cursor.fetchall()``, each holding the
            primary key of one updated row. An empty list when the base
            compiler generated no statement (nothing to update).
        """

        sql, params = self.as_sql()
        if not sql:
            # the base compiler returns an empty statement when there are
            # no values to update; appending RETURNING to it would produce
            # invalid SQL, and no rows can have been affected
            return []

        sql += self._form_returning()

        with self.connection.cursor() as cursor:
            cursor.execute(sql, params)
            primary_keys = cursor.fetchall()

        return primary_keys

    def _form_returning(self):
        """Builds the RETURNING part of the query.

        Uses the primary key's database column (not its Python attribute
        name), so primary keys with a custom ``db_column`` or parent-link
        primary keys produce valid SQL.
        """

        qn = self.connection.ops.quote_name
        return ' RETURNING %s' % qn(self.query.model._meta.pk.column)
|
|
|
|
|
|
24
|
|
|
|
|
25
|
|
|
|
|
26
|
1 |
|
class PostgresInsertCompiler(SQLInsertCompiler):
    """Compiler for SQL INSERT statements.

    Rewrites the INSERT queries generated by the base compiler so that
    they include an ON CONFLICT clause (DO UPDATE or DO NOTHING) and a
    RETURNING clause.
    """

    def __init__(self, *args, **kwargs):
        """Initializes a new instance of :see:PostgresInsertCompiler."""

        super().__init__(*args, **kwargs)
        # shortcut to the backend's identifier quoting function
        self.qn = self.connection.ops.quote_name

    def as_sql(self, return_id=False):
        """Builds the SQL INSERT statements.

        Arguments:
            return_id:
                When True, the rewritten queries only return the
                primary key column instead of all columns (``*``).

        Returns:
            A list of ``(sql, params)`` tuples, one for every query
            generated by the base compiler.
        """

        generated = super().as_sql()
        return [
            self._rewrite_insert(sql, params, return_id)
            for sql, params in generated
        ]

    def execute_sql(self, return_id=False):
        """Executes all the generated queries.

        Arguments:
            return_id:
                Passed through to :see:as_sql.

        Returns:
            A list with one dict per executed query. Each dict maps column
            names to the values of the first row that query returned. It
            is empty when the query returned no row.
        """

        results = []
        with self.connection.cursor() as cursor:
            for sql, params in self.as_sql(return_id):
                cursor.execute(sql, params)
                row = cursor.fetchone()

                if not row:
                    results.append({})
                    continue

                # read the column names while the cursor is still open and
                # from the statement that actually produced this row
                column_names = [column.name for column in cursor.description]
                results.append(dict(zip(column_names, row)))

        return results

    def _rewrite_insert(self, sql, params, return_id=False):
        """Rewrites a formed SQL INSERT query to include
        the ON CONFLICT clause.

        Arguments:
            sql:
                The SQL INSERT query to rewrite.

            params:
                The parameters passed to the query.

            return_id:
                When True, only the primary key column is put in
                the `RETURNING` clause; otherwise all columns (``*``).

        Returns:
            A tuple of the rewritten SQL query and new params.

        Raises:
            SuspiciousOperation:
                When the query's conflict action is missing or is
                neither UPDATE nor NOTHING.
        """

        # use the database column of the primary key, not its attribute
        # name, so custom `db_column`s and parent-link keys work
        returning = (
            self.qn(self.query.model._meta.pk.column) if return_id else '*'
        )

        conflict_action = getattr(self.query, 'conflict_action', None)
        action = getattr(conflict_action, 'value', None)

        if action == 'UPDATE':
            return self._rewrite_insert_update(sql, params, returning)

        if action == 'NOTHING':
            return self._rewrite_insert_nothing(sql, params, returning)

        raise SuspiciousOperation((
            '%s is not a valid conflict action, specify '
            'ConflictAction.UPDATE or ConflictAction.NOTHING.'
        ) % str(conflict_action))

    def _rewrite_insert_update(self, sql, params, returning):
        """Rewrites a formed SQL INSERT query to include
        the ON CONFLICT DO UPDATE clause.

        Every field in ``self.query.update_fields`` is set to the value
        that was proposed for insertion (``EXCLUDED``). When the query has
        an index predicate, it is added as ``WHERE`` to the conflict
        target.

        Returns:
            A tuple of the rewritten SQL query and the unchanged params.
        """

        update_columns = ', '.join(
            '{0} = EXCLUDED.{0}'.format(self.qn(field.column))
            for field in self.query.update_fields
        )

        # build the conflict target, the columns to watch
        # for conflicts
        conflict_target = self._build_conflict_target()

        index_predicate = self.query.index_predicate

        sql_template = '{insert} ON CONFLICT {conflict_target}'
        if index_predicate:
            sql_template += ' WHERE {index_predicate}'
        sql_template += ' DO UPDATE SET {update_columns} RETURNING {returning}'

        return (
            sql_template.format(
                insert=sql,
                conflict_target=conflict_target,
                update_columns=update_columns,
                returning=returning,
                index_predicate=index_predicate,
            ),
            params,
        )

    def _rewrite_insert_nothing(self, sql, params, returning):
        """Rewrites a formed SQL INSERT query to include
        the ON CONFLICT DO NOTHING clause.

        The WHERE clause of the fallback SELECT compares the conflict
        target fields to the values of the first object being inserted
        (``self.query.objs[0]``).

        Returns:
            A tuple of the rewritten SQL query and the params extended
            with the values for the fallback SELECT's WHERE clause.
        """

        # build the conflict target, the columns to watch
        # for conflicts
        conflict_target = self._build_conflict_target()

        where_clause = ' AND '.join(
            '{0} = %s'.format(self._format_field_name(field_name))
            for field_name in self.query.conflict_target
        )

        where_clause_params = tuple(
            self._format_field_value(field_name)
            for field_name in self.query.conflict_target
        )

        # tolerate params given as a list as well as a tuple
        params = tuple(params) + where_clause_params

        # this looks complicated, and it is, but it is for a reason... a normal
        # ON CONFLICT DO NOTHING doesn't return anything if the row already exists
        # so we do DO UPDATE instead that never executes to lock the row, and then
        # select from the table in case we're dealing with an existing row..
        # The no-op SET targets the model's actual primary key column (a
        # hard-coded `id` fails for tables without such a column).
        return (
            (
                'WITH insdata AS ('
                '{insert} ON CONFLICT {conflict_target} DO UPDATE'
                ' SET {pk_column} = NULL WHERE FALSE RETURNING {returning})'
                ' SELECT * FROM insdata UNION ALL'
                ' SELECT {returning} FROM {table} WHERE {where_clause} LIMIT 1;'
            ).format(
                insert=sql,
                conflict_target=conflict_target,
                pk_column=self.qn(self.query.model._meta.pk.column),
                returning=returning,
                # quoted, like the table name in the INSERT itself
                table=self.qn(self.query.objs[0]._meta.db_table),
                where_clause=where_clause,
            ),
            params,
        )

    def _build_conflict_target(self):
        """Builds the `conflict_target` for the ON CONFLICT
        clause.

        Returns:
            A parenthesized, comma-separated list of quoted column
            names and/or ``(column->'key')`` hstore key expressions.

        Raises:
            SuspiciousOperation:
                When the conflict target is not a list, or one of its
                entries does not refer to a field of the model.
        """

        if not isinstance(self.query.conflict_target, list):
            raise SuspiciousOperation((
                '%s is not a valid conflict target, specify '
                'a list of column names, or tuples with column '
                'names and hstore key.'
            ) % str(self.query.conflict_target))

        def _assert_valid_field(field_name):
            field_name = self._normalize_field_name(field_name)
            if self._get_model_field(field_name):
                return

            raise SuspiciousOperation((
                '%s is not a valid conflict target, specify '
                'a list of column names, or tuples with column '
                'names and hstore key.'
            ) % str(field_name))

        conflict_target = []
        for field_name in self.query.conflict_target:
            _assert_valid_field(field_name)

            # special handling for hstore keys
            if isinstance(field_name, tuple):
                # the key is embedded as an SQL string literal; double any
                # single quote so it cannot terminate the literal early
                hstore_key = str(field_name[1]).replace("'", "''")
                conflict_target.append(
                    "(%s->'%s')" % (
                        self._format_field_name(field_name),
                        hstore_key,
                    )
                )
            else:
                conflict_target.append(self._format_field_name(field_name))

        return '(%s)' % ','.join(conflict_target)

    def _get_model_field(self, name: str):
        """Gets the field on a model with the specified name.

        Arguments:
            name:
                The name of the field to look for.

                This can be both the actual field name, or
                the name of the column, both will work :)

                A tuple (hstore reference) is reduced to its
                field name first.

        Returns:
            The field with the specified name or None if
            no such field exists.
        """

        field_name = self._normalize_field_name(name)

        # 'pk' has special meaning and always refers to the primary
        # key of a model, we have to respect this de-facto standard behaviour
        if field_name == 'pk' and self.query.model._meta.pk:
            return self.query.model._meta.pk

        for field in self.query.model._meta.local_concrete_fields:
            if field.name == field_name or field.column == field_name:
                return field

        return None

    def _format_field_name(self, field_name) -> str:
        """Formats a field's name for usage in SQL.

        Arguments:
            field_name:
                The field name to format (a field name, column name,
                'pk', or an hstore ``(field, key)`` tuple).

        Returns:
            The quoted database column of the specified field.
        """

        field = self._get_model_field(field_name)
        return self.qn(field.column)

    def _format_field_value(self, field_name):
        """Formats a field's value for usage in SQL.

        Arguments:
            field_name:
                The name of the field to format
                the value of.

        Returns:
            The value of the field on the first object being inserted
            (``self.query.objs[0]``), prepared by
            ``SQLInsertCompiler.prepare_value`` for use as a query
            parameter.
        """

        field_name = self._normalize_field_name(field_name)
        field = self._get_model_field(field_name)

        return SQLInsertCompiler.prepare_value(
            self,
            field,
            # Note: this deliberately doesn't use `pre_save_val` as we don't
            # want things like auto_now on DateTimeField (etc.) to change the
            # value. We rely on pre_save having already been done by the
            # underlying compiler so that things like FileField have already had
            # the opportunity to save out their data.
            getattr(self.query.objs[0], field.attname)
        )

    def _normalize_field_name(self, field_name) -> str:
        """Normalizes a field name into a string by
        extracting the field name if it was specified
        as a reference to a HStore key (as a tuple).

        Arguments:
            field_name:
                The field name to normalize.

        Returns:
            The normalized field name.
        """

        if isinstance(field_name, tuple):
            field_name, _ = field_name
        return field_name
|
299
|
|
|
|
This can be caused by one of the following:
1. Missing Dependencies
This error could indicate a configuration issue with Pylint. Make sure that your libraries are available in the environment Pylint runs in by adding the commands needed to install them.
2. Missing __init__.py files
This error could also result from missing
__init__.py files in your module folders. Make sure that you place one such file in each sub-folder.