Passed
Pull Request — dev (#1034)
by Stephan
03:08 queued 21s
created

data.db.execute_sql_script()   A

Complexity

Conditions 2

Size

Total Lines 19
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 4
nop 2
dl 0
loc 19
rs 10
c 0
b 0
f 0
1
from contextlib import contextmanager
2
from types import SimpleNamespace
3
import codecs
4
import functools
5
import os
6
import time
7
8
from psycopg2.errors import DeadlockDetected, UniqueViolation
9
from sqlalchemy import create_engine, text
10
from sqlalchemy.exc import IntegrityError, OperationalError
11
from sqlalchemy.orm import sessionmaker
12
import geopandas as gpd
13
import pandas as pd
14
15
from egon.data import config
16
17
18
def asdict(row, conversions=None):
    """Turn a single SQLAlchemy query result row into a plain dictionary.

    Two kinds of rows are supported: keyed tuples (anything offering an
    `_asdict` method) and instances of mapped classes (anything offering a
    `__table__` attribute). If a row offers both, the `__table__` columns
    win. This makes the function convenient for feeding query results into
    a `pandas` `DataFrame`:

        ```python
        df = pandas.DataFrame.from_records(
            [asdict(row) for row in session.query(*columns).all()]
        )
        ```

    Parameters
    ----------
    row : SQLAlchemy query result row
    conversions : dict
        Optional mapping from column name to a function that is applied to
        that column's value. Columns not mentioned are passed through
        unchanged. The default is `None`, meaning no conversion at all.

    Returns
    -------
    dict
        Column names mapped to the (possibly converted) column values, i.e.
        `conversions[column_name](column_value)` for converted columns.

    Raises
    ------
    TypeError
        If `row` has neither an `_asdict` method nor a `__table__`
        attribute (or `_asdict` returned `None`).
    """
    mapping = None
    if hasattr(row, "_asdict"):
        mapping = row._asdict()
    if hasattr(row, "__table__"):
        # Mapped class instance: read every table column off the object.
        mapping = {}
        for column in row.__table__.columns:
            mapping[column.name] = getattr(row, column.name)

    if mapping is None:
        raise TypeError(
            "Don't know how to convert `row` argument to dict because it has"
            " neither an `_asdict`, nor a `__table__` attribute."
        )
    if conversions is None:
        return mapping

    converted = {}
    for name, value in mapping.items():
        converted[name] = conversions[name](value) if name in conversions else value
    return converted
66
67
68
@contextmanager
def access():
    """Yield a namespace bundling a database session and its connection.

    A session is opened via :func:`session_scope`. The session's connection
    is entered as a context manager, and a transaction is begun on it for
    the duration of the ``with`` block. The yielded object exposes the two
    as ``.session`` and ``.connection``.
    """
    with session_scope() as session:
        with session.connection() as connection:
            with connection.begin():
                yield SimpleNamespace(session=session, connection=connection)
73
74
75
def credentials():
    """Return local database connection parameters.

    The ``"egon-data"`` section of ``config.settings()`` is returned. For
    every ``--database-*`` entry present in it, the value is additionally
    stored under the key expected by the rest of this module:
    ``POSTGRES_DB``, ``POSTGRES_PASSWORD``, ``HOST``, ``PORT`` and
    ``POSTGRES_USER``.

    NOTE(review): the section is updated in place, so the dictionary
    obtained from ``config.settings()`` itself gains these keys. Whether
    that object is shared with other callers depends on ``config.settings``
    (not visible here).

    Returns
    -------
    dict
        Complete DB connection information
    """
    flag_to_key = {
        "--database-name": "POSTGRES_DB",
        "--database-password": "POSTGRES_PASSWORD",
        "--database-host": "HOST",
        "--database-port": "PORT",
        "--database-user": "POSTGRES_USER",
    }
    settings = config.settings()["egon-data"]

    # Collect the aliases first so the dict is not mutated while iterating.
    aliases = {}
    for flag in settings:
        if flag in flag_to_key:
            aliases[flag_to_key[flag]] = settings[flag]

    settings.update(aliases)
    return settings
98
99
100
def engine():
    """Return the SQLAlchemy engine for the local database.

    One engine is created per process id and cached on the function object
    (``engine.cache``). Repeated calls within the same process therefore
    return the same engine, while a different process (e.g. a forked
    worker — presumably the reason for keying on the pid) builds its own.

    Connection parameters are taken from :func:`credentials`. User name and
    password are percent-encoded before being placed in the connection URL.
    Characters with a special meaning in URLs (``@``, ``:``, ``/``, ``%``
    ...) would otherwise break or silently alter the URL. SQLAlchemy decodes
    them again when parsing the URL.

    Returns
    -------
    sqlalchemy.engine.Engine
        Engine connected to the local PostgreSQL database via psycopg2.
    """
    from urllib.parse import quote

    if not hasattr(engine, "cache"):
        engine.cache = {}
    pid = os.getpid()
    if pid in engine.cache:
        return engine.cache[pid]

    db_config = credentials()
    # ``str`` mirrors the implicit conversion the f-string used to do, so
    # non-string config values (e.g. a numeric password) keep working.
    user = quote(str(db_config["POSTGRES_USER"]), safe="")
    password = quote(str(db_config["POSTGRES_PASSWORD"]), safe="")
    engine.cache[pid] = create_engine(
        f"postgresql+psycopg2://{user}:{password}@{db_config['HOST']}:"
        f"{db_config['PORT']}/{db_config['POSTGRES_DB']}",
        echo=False,
    )
    return engine.cache[pid]
115
116
117
def execute_sql(sql_string):
    """Run a SQL statement that is given as a plain string.

    The string is wrapped in :func:`sqlalchemy.text` (yielding a
    `sqlalchemy.sql.expression.TextClause`). It is executed on the
    connection provided by :func:`access`, i.e. inside that context's
    session and transaction.

    Parameters
    ----------
    sql_string : str
        SQL expression
    """
    with access() as database:
        connection = database.connection
        connection.execute(text(sql_string))
131
132
133
def submit_comment(json, schema, table):
    """Store a JSON document as the comment of a database table.

    We use the `Open Energy Metadata <https://github.com/OpenEnergyPlatform/
    oemetadata/blob/develop/metadata/v141/metadata_key_description.md>`_
    standard for describing our data. Metadata is stored as JSON in the table
    comment.

    After the comment is written, it is read back and cast to ``json`` in
    the database. The cast raises an error if the stored comment is not
    valid JSON.

    Parameters
    ----------
    json : str
        The comment as a SQL string literal. The value is inserted verbatim
        after ``COMMENT ON TABLE ... IS``, so it must already be quoted
        (e.g. the JSON text wrapped in single quotes, with embedded single
        quotes escaped).
    schema : str
        The target table's database schema
    table : str
        Database table on which to put the given comment
    """
    comment_sql = "COMMENT ON TABLE {0}.{1} IS ".format(schema, table) + json + ";"
    validation_sql = f"SELECT obj_description('{schema}.{table}'::regclass)::json"

    execute_sql(comment_sql)

    # Casting the stored comment to JSON fails loudly if it is not valid JSON.
    execute_sql(validation_sql)
163
164
165
def execute_sql_script(script, encoding="utf-8-sig"):
    """Execute a SQL script given as a file name.

    The whole file is read and passed to :func:`execute_sql` as a single
    string.

    Parameters
    ----------
    script : str
        Path of the SQL-script
    encoding : str
        Encoding which is used for the SQL file. The default is "utf-8-sig",
        which also strips a leading byte order mark, if present.

    Returns
    -------
    None.
    """
    # The built-in ``open`` replaces the deprecated ``codecs.open``.
    # ``newline=""`` disables newline translation, so the text is returned
    # exactly as stored, just like ``codecs.open`` (which reads in binary
    # mode) did.
    with open(script, "r", encoding=encoding, newline="") as fd:
        sqlfile = fd.read()

    execute_sql(sqlfile)
184
185
186
@contextmanager
def session_scope():
    """Provide a transactional scope around a series of operations.

    A fresh session bound to :func:`engine` is yielded. When the ``with``
    block finishes normally, the session is committed. If the block (or the
    commit) raises, the session is rolled back and the exception
    re-raised. In every case, the session is closed afterwards.
    """
    session = sessionmaker(bind=engine())()
    try:
        yield session
        session.commit()
    except BaseException:
        # Deliberately catch everything (same as a bare ``except``) so the
        # rollback always happens; the exception is re-raised right away.
        session.rollback()
        raise
    finally:
        session.close()
199
200
201
def session_scoped(function):
    """Provide a session scope to a function.

    Can be used as a decorator like this:

    >>> @session_scoped
    ... def get_bind(session):
    ...     return session.get_bind()
    ...
    >>> get_bind()  # doctest: +SKIP
    Engine(postgresql+psycopg2://egon:***@127.0.0.1:59734/egon-data)

    The decorated function must accept a keyword parameter named
    `session`. Callers do not pass it: on every call, the wrapper opens a
    :func:`session_scope` and supplies that session. All other positional
    and keyword arguments are forwarded unchanged. Besides saving an
    indentation level, this also guarantees commit/rollback handling for
    the whole function body.
    """

    @functools.wraps(function)
    def with_session(*args, **kwargs):
        with session_scope() as session:
            return function(*args, session=session, **kwargs)

    return with_session
226
227
228
def select_dataframe(sql, index_col=None, warning=True):
    """Select data from local database as pandas.DataFrame

    Parameters
    ----------
    sql : str
        SQL query to be executed.
    index_col : str, optional
        Column(s) to set as index(MultiIndex). The default is None.
    warning : bool, optional
        If exactly ``True`` (the default), a warning including the query is
        printed when the result is empty.

    Returns
    -------
    df : pandas.DataFrame
        Data returned from SQL statement.
    """
    with access() as database:
        result = pd.read_sql(sql, database.connection, index_col=index_col)

    is_empty = result.size == 0
    if is_empty and warning is True:
        print(f"WARNING: No data returned by statement: \n {sql}")

    return result
252
253
254
def select_geodataframe(sql, index_col=None, geom_col="geom", epsg=3035):
    """Select data from local database as geopandas.GeoDataFrame

    Non-empty results are reprojected to `epsg`. Empty results are returned
    as read, after printing a warning that includes the query.

    Parameters
    ----------
    sql : str
        SQL query to be executed.
    index_col : str, optional
        Column(s) to set as index(MultiIndex). The default is None.
    geom_col : str, optional
        column name to convert to shapely geometries. The default is 'geom'.
    epsg : int, optional
        EPSG code specifying output projection. The default is 3035.

    Returns
    -------
    gdf : geopandas.GeoDataFrame
        Data returned from SQL statement.
    """
    with access() as database:
        gdf = gpd.read_postgis(
            sql, database.connection, index_col=index_col, geom_col=geom_col
        )

    if gdf.size:
        return gdf.to_crs(epsg=epsg)

    print(f"WARNING: No data returned by statement: \n {sql}")
    return gdf
287
288
289
def next_etrago_id(component):
    """Select next id value for components in etrago tables

    The id is one more than the current maximum of the component's id column
    in ``grid.egon_etrago_<component>``, or 1 if that maximum is falsy
    (e.g. the table is empty and ``MAX`` returns NULL).

    Parameters
    ----------
    component : str
        Name of component. Transformers use the ``trafo_id`` column; every
        other component uses ``<component>_id``.

    Returns
    -------
    next_id : int
        Next index value

    Notes
    -----
    To catch concurrent DB commits, consider to use
    :func:`check_db_unique_violation` instead.
    """
    id_column = "trafo_id" if component == "transformer" else f"{component}_id"

    query = f"""
        SELECT MAX({id_column}) FROM grid.egon_etrago_{component}
        """
    max_id = select_dataframe(query)["max"][0]

    return max_id + 1 if max_id else 1
325
326
327
def check_db_unique_violation(func):
    """Wrapper to catch psycopg's UniqueViolation errors during concurrent DB
    commits.

    Calls `func` and retries it when it fails with either of these errors:

    - an SQLAlchemy ``IntegrityError`` caused by psycopg's ``UniqueViolation``;
    - an SQLAlchemy ``OperationalError`` caused by psycopg's
      ``DeadlockDetected``.

    The wrapper waits 3 seconds between attempts. After 10 retries, the
    original exception is re-raised. Any other exception is propagated
    immediately. This includes ``IntegrityError`` or ``OperationalError``
    with a different underlying cause.

    Preferably used with :func:`next_etrago_id`.

    Parameters
    ----------
    func : callable
        Function to wrap. It receives the arguments passed to the wrapper and
        may be called several times.

    Returns
    -------
    callable
        Wrapper returning the result of the first successful call of `func`.
        Name and docstring of `func` are preserved.

    Notes
    -----
    Background: using :func:`next_etrago_id` may cause trouble if tasks are
    executed simultaneously, cf.
    https://github.com/openego/eGon-data/issues/514

    Important: your function requires a way to escape the violation,
    otherwise every retry fails the same way until the retry limit is
    reached. In case of eTraGo tables you can use :func:`next_etrago_id`
    inside the wrapped function, see example below.

    Examples
    --------
    Can be used as a decorator like this:

    >>> @check_db_unique_violation
    ... def commit_something_to_database():
    ...     # commit something here
    ...     return
    ...
    >>> commit_something_to_database()  # doctest: +SKIP

    Add new bus to eTraGo's bus table:

    >>> from egon.data import db
    >>> from egon.data.datasets.etrago_setup import EgonPfHvBus
    >>> @check_db_unique_violation
    ... def add_etrago_bus():
    ...     bus_id = db.next_etrago_id("bus")
    ...     with db.session_scope() as session:
    ...         session.add(
    ...             EgonPfHvBus(
    ...                 scn_name="eGon2035",
    ...                 bus_id=bus_id,
    ...                 v_nom=1,
    ...                 carrier="whatever",
    ...                 x=52,
    ...                 y=13,
    ...                 geom="<some_geom>"
    ...             )
    ...         )
    ...         session.commit()
    ...
    >>> add_etrago_bus()  # doctest: +SKIP
    """
    max_retries = 10
    pause_seconds = 3

    @functools.wraps(func)
    def commit(*args, **kwargs):
        retries = 0
        while True:
            try:
                return func(*args, **kwargs)
            except IntegrityError as e:
                if not isinstance(e.orig, UniqueViolation):
                    raise
                error, reason = e, "Entry is not unique"
            except OperationalError as e:
                # Only deadlocks are transient; any other operational error
                # (e.g. a lost connection) must propagate instead of being
                # retried forever.
                if not isinstance(e.orig, DeadlockDetected):
                    raise
                error, reason = e, "Deadlock detected"

            retries += 1
            if retries > max_retries:
                print(f"No success after {max_retries} retries, exiting...")
                raise error
            print(f"{reason}, retrying...")
            time.sleep(pause_seconds)

    return commit
419
420
421
def assign_gas_bus_id(dataframe, scn_name, carrier):
    """Assigns bus_ids to points according to location.

    The points are taken from the given `dataframe`. The geometries used to
    assign a `bus_id` to them are taken from the `grid.egon_gas_voronoi`
    table, filtered by `scn_name` and `carrier` and projected to EPSG:4326.
    Each point receives the `bus_id` of the Voronoi polygon(s) it is joined
    to by :func:`geopandas.sjoin`. The value is also copied into a `bus`
    column.

    NOTE(review): the Voronoi geometries are requested in EPSG:4326, so
    `dataframe` is presumably expected in that CRS as well — confirm with
    callers.

    Parameters
    ----------
    dataframe : geopandas.GeoDataFrame
        DataFrame containing points
    scn_name : str
        Name of the scenario
    carrier : str
        Name of the carrier

    Returns
    -------
    res : geopandas.GeoDataFrame
        Dataframe including bus_id. Points that lie in no Voronoi polygon
        are not contained (a warning is printed for them).
    """

    voronoi = select_geodataframe(
        f"""
        SELECT bus_id, geom FROM grid.egon_gas_voronoi
        WHERE scn_name = '{scn_name}' AND carrier = '{carrier}';
        """,
        epsg=4326,
    )

    res = gpd.sjoin(dataframe, voronoi)
    res["bus"] = res["bus_id"]
    res = res.drop(columns=["index_right"])

    # ``gpd.sjoin`` performs an inner join by default, so points outside of
    # every Voronoi polygon are dropped from ``res`` instead of appearing
    # with a missing bus id. The not-null check below cannot see them, so
    # report them here. (Points are identified by index label.)
    n_unattached = int((~dataframe.index.isin(res.index)).sum())
    if n_unattached:
        print(
            f"WARNING: {n_unattached} point(s) are not located in any "
            f"{carrier} Voronoi polygon and were dropped."
        )

    # Assert that all remaining points have a bus_id
    assert (
        res.bus.notnull().all()
    ), f"Some points are not attached to a {carrier} bus."

    return res
461