1
|
|
|
""" |
2
|
|
|
byceps.database |
3
|
|
|
~~~~~~~~~~~~~~~ |
4
|
|
|
|
5
|
|
|
Database utilities. |
6
|
|
|
|
7
|
|
|
:Copyright: 2006-2019 Jochen Kupperschmidt |
8
|
|
|
:License: Modified BSD, see LICENSE for details. |
9
|
|
|
""" |
10
|
|
|
|
11
|
|
|
from typing import Any, Callable, Dict, Iterable, Optional, TypeVar
import uuid

from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.sql.dml import Insert
from sqlalchemy.sql.schema import Table

from flask_sqlalchemy import BaseQuery, Pagination, SQLAlchemy
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Query
21
|
|
|
|
22
|
|
|
|
23
|
|
|
# Type variables for the input ("from") and output ("to") of a mapper.
F, T = TypeVar('F'), TypeVar('T')

# A callable that converts one item of type `F` into a value of type `T`.
Mapper = Callable[[F], T]
27
|
|
|
|
28
|
|
|
|
29
|
|
|
# Shared Flask-SQLAlchemy handle. Autoflush is disabled, so issuing a
# query does not implicitly flush pending changes to the database.
db = SQLAlchemy(session_options=dict(autoflush=False))

# Expose PostgreSQL's JSONB column type as `db.JSONB`.
db.JSONB = JSONB
33
|
|
|
|
34
|
|
|
|
35
|
|
|
class Uuid(UUID):
    """PostgreSQL `UUID` column type configured with `as_uuid=True`,
    i.e. values are handled as Python `uuid.UUID` objects.
    """

    def __init__(self) -> None:
        super(Uuid, self).__init__(as_uuid=True)


# Expose the type as `db.Uuid`, alongside `db.JSONB`.
db.Uuid = Uuid
42
|
|
|
|
43
|
|
|
|
44
|
|
|
def generate_uuid() -> uuid.UUID:
    """Return a new random (version 4) UUID."""
    new_id = uuid.uuid4()
    return new_id
47
|
|
|
|
48
|
|
|
|
49
|
|
|
def paginate(query: Query, page: int, per_page: int,
             *, item_mapper: Optional[Mapper] = None) -> Pagination:
    """Return `per_page` items from page `page` of `query`.

    A `page` value below 1 is treated as page 1.

    If `item_mapper` is given, it is applied to every fetched item and
    the returned pagination holds the mapped values instead.

    :raises ValueError: if `per_page` is less than 1
    """
    if page < 1:
        page = 1

    if per_page < 1:
        raise ValueError('The number of items per page must be positive.')

    offset = (page - 1) * per_page

    items = query \
        .limit(per_page) \
        .offset(offset) \
        .all()

    item_count = len(items)
    if page == 1 and item_count < per_page:
        # The first page is not full, so it already contains every
        # matching row; no separate COUNT query is needed.
        total = item_count
    else:
        # Ordering does not affect the count, so remove ORDER BY from
        # the counting query.
        total = query.order_by(None).count()

    if item_mapper is not None:
        items = [item_mapper(item) for item in items]

    # Intentionally pass no query object.
    return Pagination(None, page, per_page, total, items)
76
|
|
|
|
77
|
|
|
|
78
|
|
|
def insert_ignore_on_conflict(table: Table, values: Dict[str, Any]) -> None:
    """Insert a row built from `values` into `table`, or do nothing if
    it conflicts with an existing row on the table's primary key.

    The primary key must be part of `values`. The session is committed
    afterwards.
    """
    statement = (
        insert(table)
        .values(**values)
        .on_conflict_do_nothing(constraint=table.primary_key)
    )

    session = db.session
    session.execute(statement)
    session.commit()
88
|
|
|
|
89
|
|
|
|
90
|
|
|
def upsert(table: Table, identifier: Dict[str, Any],
           replacement: Dict[str, Any]) -> None:
    """Insert the record identified by `identifier` with the values in
    `replacement`, or, if it already exists (primary key conflict),
    update it with `replacement`. The session is committed afterwards.
    """
    db.session.execute(_build_upsert_query(table, identifier, replacement))
    db.session.commit()
99
|
|
|
|
100
|
|
|
|
101
|
|
|
def upsert_many(table: Table, identifiers: Iterable[Dict[str, Any]],
                replacement: Dict[str, Any]) -> None:
    """Insert or update each record identified by one of `identifiers`,
    setting the values in `replacement` on every one of them.

    One upsert statement is executed per identifier; all of them are
    committed together by a single commit at the end.
    """
    for identifier in identifiers:
        query = _build_upsert_query(table, identifier, replacement)
        db.session.execute(query)

    db.session.commit()
111
|
|
|
|
112
|
|
|
|
113
|
|
|
def _build_upsert_query(table: Table, identifier: Dict[str, Any],
                        replacement: Dict[str, Any]) -> Insert:
    """Build a PostgreSQL `INSERT ... ON CONFLICT DO UPDATE` statement.

    The inserted row combines `identifier` and `replacement` (keys in
    `replacement` win). On a primary key conflict, only the columns in
    `replacement` are updated.
    """
    row = {**identifier, **replacement}

    return (
        insert(table)
        .values(**row)
        .on_conflict_do_update(
            constraint=table.primary_key,
            set_=replacement,
        )
    )
123
|
|
|
|