Passed
Push — master ( d2c491...6ef235 )
by Jochen
01:44
created

byceps.database.insert_ignore_on_conflict()   A

Complexity

Conditions 1

Size

Total Lines 10
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 2
dl 0
loc 10
rs 10
c 0
b 0
f 0
1
"""
2
byceps.database
3
~~~~~~~~~~~~~~~
4
5
Database utilities.
6
7
:Copyright: 2006-2019 Jochen Kupperschmidt
8
:License: Modified BSD, see LICENSE for details.
9
"""
10
11
from typing import Any, Callable, Dict, Iterable, TypeVar
12
import uuid
13
14
from sqlalchemy.dialects.postgresql import insert
15
from sqlalchemy.sql.dml import Insert
16
from sqlalchemy.sql.schema import Table
17
18
from flask_sqlalchemy import BaseQuery, Pagination, SQLAlchemy
19
from sqlalchemy.dialects.postgresql import JSONB, UUID
20
from sqlalchemy.orm import Query
21
22
23
# Type variables for the generic item mapper: `F` is the type an item is
# mapped from, `T` the type it is mapped to.
F = TypeVar('F')
T = TypeVar('T')

# A callable that converts a single item of type `F` into one of type `T`.
Mapper = Callable[[F], T]
27
28
29
# Module-level Flask-SQLAlchemy handle; its sessions are configured with
# autoflush turned off.
db = SQLAlchemy(session_options={'autoflush': False})
30
31
32
# Attach the PostgreSQL `JSONB` type to `db` so it is reachable as `db.JSONB`.
db.JSONB = JSONB
33
34
35
class Uuid(UUID):
    """PostgreSQL `UUID` column type that is always created with
    `as_uuid=True`.
    """

    def __init__(self) -> None:
        # Takes no arguments on purpose: `as_uuid=True` is fixed here so
        # that callers cannot forget (or override) it.
        super().__init__(as_uuid=True)
39
40
41
# Attach the custom `Uuid` type to `db` so it is reachable as `db.Uuid`.
db.Uuid = Uuid
42
43
44
def generate_uuid() -> uuid.UUID:
    """Generate a random UUID (Universally Unique IDentifier).

    :returns: a new UUID created by `uuid.uuid4()`
    """
    return uuid.uuid4()
47
48
49
def paginate(query: Query, page: int, per_page: int,
             *, item_mapper: Mapper = None) -> Pagination:
    """Return `per_page` items from page `page`.

    :param query: the query to fetch the page's items from
    :param page: 1-based page number; values below 1 are treated as 1
    :param per_page: maximum number of items on one page
    :param item_mapper: optional callable applied to every fetched item
        before it is put into the result
    :returns: a `Pagination` object (without query) carrying the
        effective page number, `per_page`, the total item count, and
        the (optionally mapped) items
    :raises ValueError: if `per_page` is less than 1
    """
    if page < 1:
        page = 1

    if per_page < 1:
        raise ValueError('The number of items per page must be positive.')

    offset = (page - 1) * per_page

    items = (
        query
        .limit(per_page)
        .offset(offset)
        .all()
    )

    item_count = len(items)
    if page == 1 and item_count < per_page:
        # The first page is not even full, so it holds everything;
        # no separate count query is necessary.
        total = item_count
    else:
        # Ordering is irrelevant for counting, so drop it.
        total = query.order_by(None).count()

    if item_mapper is not None:
        items = [item_mapper(item) for item in items]

    # Intentionally pass no query object.
    return Pagination(None, page, per_page, total, items)
76
77
78
def insert_ignore_on_conflict(table: Table, values: Dict[str, Any]) -> None:
    """Insert the record identified by the primary key (specified as
    part of the values), or do nothing on conflict.

    The statement is executed on `db.session`, which is committed
    afterwards.

    :param table: the table to insert into
    :param values: column names mapped to the values to insert
    """
    query = (
        insert(table)
        .values(**values)
        .on_conflict_do_nothing(constraint=table.primary_key)
    )

    db.session.execute(query)
    db.session.commit()
88
89
90
def upsert(table: Table, identifier: Dict[str, Any],
           replacement: Dict[str, Any]) -> None:
    """Insert or update the record identified by `identifier` with value
    `replacement`.

    The statement is executed on `db.session`, which is committed
    afterwards.

    :param table: the table to write to
    :param identifier: column names and values identifying the record
    :param replacement: column names and values to set on the record
    """
    query = _build_upsert_query(table, identifier, replacement)

    db.session.execute(query)
    db.session.commit()
99
100
101
def upsert_many(table: Table, identifiers: Iterable[Dict[str, Any]],
                replacement: Dict[str, Any]) -> None:
    """Insert or update each record identified by one of `identifiers`
    with value `replacement`.

    One statement per identifier is executed on `db.session`; the
    session is committed once after all of them have been executed.

    :param table: the table to write to
    :param identifiers: for each record, the column names and values
        identifying it
    :param replacement: column names and values to set on every record
    """
    for identifier in identifiers:
        query = _build_upsert_query(table, identifier, replacement)
        db.session.execute(query)

    # A single commit for the whole batch.
    db.session.commit()
111
112
113
def _build_upsert_query(table: Table, identifier: Dict[str, Any],
                        replacement: Dict[str, Any]) -> Insert:
    """Build the insert statement that updates the record on conflict.

    :param table: the table to write to
    :param identifier: column names and values identifying the record
    :param replacement: column names and values to set on the record
    :returns: an insert of `identifier` merged with `replacement` that,
        on a conflict with the table's primary key constraint, sets
        the `replacement` values on the existing record instead
    """
    # Merge into a new dict so neither argument is modified; on duplicate
    # keys the value from `replacement` wins.
    values = {**identifier, **replacement}

    return (
        insert(table)
        .values(**values)
        .on_conflict_do_update(
            constraint=table.primary_key,
            set_=replacement)
    )
123