Completed
Push — main ( dc9c2e...80557c )
by Jochen
05:20
created

byceps.services.party.service.get_active_parties()   A

Complexity

Conditions 4

Size

Total Lines 25
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 10
CRAP Score 4.0119

Importance

Changes 0
Metric Value
cc 4
eloc 16
nop 3
dl 0
loc 25
ccs 10
cts 11
cp 0.9091
crap 4.0119
rs 9.6
c 0
b 0
f 0
1
"""
2
byceps.services.party.service
3
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4
5
:Copyright: 2006-2020 Jochen Kupperschmidt
6
:License: Modified BSD, see LICENSE for details.
7
"""
8
9 1
import dataclasses
10 1
from datetime import date, datetime, timedelta
11 1
from typing import Dict, List, Optional, Set, Union
12
13 1
from ...database import db, paginate, Pagination
14 1
from ...typing import BrandID, PartyID
15
16 1
from ..brand.models.brand import Brand as DbBrand
17 1
from ..brand import service as brand_service
18
19 1
from .models.party import Party as DbParty
20 1
from .models.setting import Setting as DbSetting
21 1
from .transfer.models import Party, PartyWithBrand
22
23
24 1
class UnknownPartyId(Exception):
    """Signal that no party exists for a requested party ID.

    Within this module the offending ID is passed as the exception's
    only argument.
    """
26
27
28 1
def create_party(
    party_id: PartyID,
    brand_id: BrandID,
    title: str,
    starts_at: datetime,
    ends_at: datetime,
    *,
    max_ticket_quantity: Optional[int] = None,
) -> Party:
    """Persist a new party and return it as a transfer object.

    `max_ticket_quantity` is optional and defaults to `None`. The new
    entity is added to the session and committed immediately.
    """
    db_party = DbParty(
        party_id,
        brand_id,
        title,
        starts_at,
        ends_at,
        max_ticket_quantity=max_ticket_quantity,
    )

    session = db.session
    session.add(db_party)
    session.commit()

    return _db_entity_to_party(db_party)
51
52
53 1
def update_party(
    party_id: PartyID,
    title: str,
    starts_at: datetime,
    ends_at: datetime,
    max_ticket_quantity: Optional[int],
    ticket_management_enabled: bool,
    seat_management_enabled: bool,
    canceled: bool,
    archived: bool,
) -> Party:
    """Overwrite all editable attributes of an existing party.

    The session is committed and the updated party is returned.

    Raises `UnknownPartyId` if no party with that ID exists.
    """
    db_party = DbParty.query.get(party_id)
    if db_party is None:
        raise UnknownPartyId(party_id)

    db_party.title = title
    db_party.starts_at = starts_at
    db_party.ends_at = ends_at
    db_party.max_ticket_quantity = max_ticket_quantity
    db_party.ticket_management_enabled = ticket_management_enabled
    db_party.seat_management_enabled = seat_management_enabled
    db_party.canceled = canceled
    db_party.archived = archived

    db.session.commit()

    return _db_entity_to_party(db_party)
82
83
84 1
def delete_party(party_id: PartyID) -> None:
    """Remove the party with that ID together with its settings.

    The settings rows are deleted before the party row (presumably because
    they reference the party; confirm against the schema). Both deletions
    are committed in a single commit.
    """
    settings_of_party = db.session.query(DbSetting).filter_by(party_id=party_id)
    settings_of_party.delete()

    the_party = db.session.query(DbParty).filter_by(id=party_id)
    the_party.delete()

    db.session.commit()
95
96
97 1
def count_parties() -> int:
    """Return the total number of parties, across every brand."""
    all_parties = DbParty.query
    return all_parties.count()
100
101
102 1
def count_parties_for_brand(brand_id: BrandID) -> int:
    """Return how many parties belong to the given brand."""
    parties_of_brand = DbParty.query.filter_by(brand_id=brand_id)
    return parties_of_brand.count()
107
108
109 1
def find_party(party_id: PartyID) -> Optional[Party]:
    """Look up a party by ID.

    Returns the party as a transfer object, or `None` if there is no
    party with that ID.
    """
    db_party = DbParty.query.get(party_id)
    return None if db_party is None else _db_entity_to_party(db_party)
117
118
119 1
def get_party(party_id: PartyID) -> Party:
    """Return the party with that id.

    Unlike `find_party`, this never returns `None`.

    Raises `UnknownPartyId` if no party with that id exists.
    """
    party = find_party(party_id)

    if party is None:
        raise UnknownPartyId(party_id)

    return party
127
128
129 1
def get_all_parties() -> List[Party]:
    """Return every party, regardless of brand, in no particular order."""
    db_parties = DbParty.query.all()
    return list(map(_db_entity_to_party, db_parties))
135
136
137 1
def get_all_parties_with_brands() -> List[PartyWithBrand]:
    """Return all parties, each bundled with its brand.

    The `brand` relationship is eagerly loaded with a joined load, so the
    brands are fetched together with the parties.
    """
    query = DbParty.query.options(db.joinedload('brand'))
    return [_db_entity_to_party_with_brand(db_party) for db_party in query.all()]
144
145
146 1
def get_active_parties(
    brand_id: Optional[BrandID] = None, *, include_brands: bool = True
) -> List[Union[Party, PartyWithBrand]]:
    """Return parties that are neither canceled nor archived.

    The result is ordered by start time, earliest first.

    If `brand_id` is given, only that brand's parties are returned.
    With `include_brands` (the default), the brand relationship is
    eagerly loaded and each party is returned as a `PartyWithBrand`;
    otherwise plain `Party` objects are returned.
    """
    query = DbParty.query

    if brand_id is not None:
        query = query.filter_by(brand_id=brand_id)

    if include_brands:
        query = query.options(db.joinedload('brand'))
        to_transfer_object = _db_entity_to_party_with_brand
    else:
        to_transfer_object = _db_entity_to_party

    db_parties = (
        query.filter_by(canceled=False)
        .filter_by(archived=False)
        .order_by(DbParty.starts_at)
        .all()
    )

    return [to_transfer_object(db_party) for db_party in db_parties]
171
172
173 1
def get_archived_parties_for_brand(brand_id: BrandID) -> List[Party]:
    """Return the brand's archived parties, latest start time first."""
    db_parties = (
        DbParty.query.filter_by(brand_id=brand_id)
        .filter_by(archived=True)
        .order_by(DbParty.starts_at.desc())
        .all()
    )
    return list(map(_db_entity_to_party, db_parties))
182
183
184 1
def get_parties(party_ids: Set[PartyID]) -> List[Party]:
    """Return the parties whose IDs are in `party_ids`.

    IDs without a matching party simply do not appear in the result, and
    the query imposes no ordering. An empty set returns an empty list
    without touching the database.
    """
    if not party_ids:
        return []

    matching = DbParty.query.filter(DbParty.id.in_(party_ids))
    return [_db_entity_to_party(db_party) for db_party in matching.all()]
194
195
196 1
def get_parties_for_brand(brand_id: BrandID) -> List[Party]:
    """Return all parties of the given brand, in no particular order."""
    db_parties = DbParty.query.filter_by(brand_id=brand_id).all()
    return list(map(_db_entity_to_party, db_parties))
203
204
205 1
def get_parties_for_brand_paginated(
    brand_id: BrandID, page: int, per_page: int
) -> Pagination:
    """Return one page of the brand's parties, latest start time first.

    `_db_entity_to_party` is passed to `paginate` as the item mapper.
    """
    parties_of_brand = DbParty.query.filter_by(brand_id=brand_id)
    newest_first = parties_of_brand.order_by(DbParty.starts_at.desc())

    return paginate(
        newest_first, page, per_page, item_mapper=_db_entity_to_party
    )
214
215
216 1
def get_party_count_by_brand_id() -> Dict[BrandID, int]:
    """Map every brand ID to the number of parties of that brand.

    Because parties are outer-joined onto brands, brands without any
    party are included with a count of 0.
    """
    party_count_expr = db.func.count(DbParty.id)

    rows = (
        db.session.query(DbBrand.id, party_count_expr)
        .outerjoin(DbParty)
        .group_by(DbBrand.id)
        .all()
    )

    return {brand_id: party_count for brand_id, party_count in rows}
228
229
230 1
def _db_entity_to_party(party: DbParty) -> Party:
    """Build a `Party` transfer object from a `DbParty` entity.

    The values are passed positionally, so their order here must match
    the field order of `Party`.
    """
    values = (
        party.id,
        party.brand_id,
        party.title,
        party.starts_at,
        party.ends_at,
        party.max_ticket_quantity,
        party.ticket_management_enabled,
        party.seat_management_enabled,
        party.canceled,
        party.archived,
    )
    return Party(*values)
243
244
245 1
def _db_entity_to_party_with_brand(party_entity: DbParty) -> PartyWithBrand:
    """Build a `PartyWithBrand` from a party entity and its brand.

    The party's fields are flattened with `dataclasses.astuple` and passed
    positionally. The converted brand follows as the `brand` keyword.
    """
    party_values = dataclasses.astuple(_db_entity_to_party(party_entity))
    converted_brand = brand_service._db_entity_to_brand(party_entity.brand)

    return PartyWithBrand(*party_values, brand=converted_brand)
250
251
252 1
def get_party_days(party: Party) -> List[date]:
    """Return the calendar dates the party spans, in ascending order.

    Both the start date and the end date are included; only the date part
    of `starts_at`/`ends_at` is considered.

    Raises `ValueError` if the party's start date is after its end date.
    """
    first_day = party.starts_at.date()
    last_day = party.ends_at.date()

    if first_day > last_day:
        raise ValueError('Start date must not be after end date.')

    day_count = (last_day - first_day).days + 1
    return [first_day + timedelta(days=offset) for offset in range(day_count)]
270