Passed
Push — dev ( c9030d...1c6fbe )
by Stephan
02:06 queued 12s
created

data.db.engine_for()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 8
rs 10
c 0
b 0
f 0
cc 1
nop 1
1
from contextlib import contextmanager
2
import codecs
3
import functools
4
import os
5
import time
6
7
from psycopg2.errors import DeadlockDetected, UniqueViolation
8
from sqlalchemy import create_engine, text
9
from sqlalchemy.exc import IntegrityError, OperationalError
10
from sqlalchemy.orm import sessionmaker
11
import geopandas as gpd
12
import pandas as pd
13
14
from egon.data import config
15
16
17
def credentials():
    """Return local database connection parameters.

    The ``egon-data`` section of the configuration is extended with the keys
    ``POSTGRES_DB``, ``POSTGRES_PASSWORD``, ``HOST``, ``PORT`` and
    ``POSTGRES_USER``, each copied from the matching ``--database-*`` flag
    when that flag is present.

    Note that the section object obtained from ``config.settings()`` is
    updated in place and returned itself (not a copy).

    Returns
    -------
    dict
        Complete DB connection information
    """
    flag_to_key = {
        "--database-name": "POSTGRES_DB",
        "--database-password": "POSTGRES_PASSWORD",
        "--database-host": "HOST",
        "--database-port": "PORT",
        "--database-user": "POSTGRES_USER",
    }
    section = config.settings()["egon-data"]

    # Collect the aliases first; ``section`` must not change while we
    # iterate over it.
    aliases = {}
    for option in section:
        alias = flag_to_key.get(option)
        if alias is not None:
            aliases[alias] = section[option]

    section.update(aliases)
    return section
40
41
42
@functools.lru_cache(maxsize=None)
def engine_for(pid):  # pylint: disable=unused-argument
    """Create the SQLAlchemy engine for the local database.

    The result is cached per ``pid`` via :func:`functools.lru_cache`. The
    argument is not used otherwise and only serves as the cache key, so
    each process id gets an engine of its own.

    Parameters
    ----------
    pid : int
        Process id used as cache key.

    Returns
    -------
    sqlalchemy.engine.Engine
        Engine connecting through ``psycopg2`` with the settings returned
        by :func:`credentials`.
    """
    from urllib.parse import quote

    db_config = credentials()
    # Percent-encode the password. Characters such as "@", ":", "/", "%"
    # or "#" would otherwise corrupt the URL. SQLAlchemy decodes the
    # password again when parsing the URL.
    password = quote(str(db_config["POSTGRES_PASSWORD"]), safe="")
    return create_engine(
        f"postgresql+psycopg2://{db_config['POSTGRES_USER']}:"
        f"{password}@{db_config['HOST']}:"
        f"{db_config['PORT']}/{db_config['POSTGRES_DB']}",
        echo=False,
    )
51
52
53
def engine():
    """Return the engine for the local database.

    The engine is looked up by the id of the current process, so repeated
    calls within one process hand back the same cached engine.
    """
    current_pid = os.getpid()
    return engine_for(current_pid)
56
57
58
def execute_sql(sql_string):
    """Execute a SQL expression given as string.

    The plain string is wrapped into a
    `sqlalchemy.sql.expression.TextClause` via ``text`` and executed on a
    fresh connection with the ``autocommit`` execution option set. Any
    result of the statement is discarded.

    Parameters
    ----------
    sql_string : str
        SQL expression
    """
    connection = engine().connect().execution_options(autocommit=True)
    with connection:
        connection.execute(text(sql_string))
74
75
76
def submit_comment(json, schema, table):
    """Add comment to table.

    We use `Open Energy Metadata <https://github.com/OpenEnergyPlatform/
    oemetadata/blob/develop/metadata/v141/metadata_key_description.md>`_
    standard for describing our data. Metadata is stored as JSON in the table
    comment.

    After writing, the comment is read back and cast to ``json`` in the
    database, so invalid JSON makes the second statement fail.

    Parameters
    ----------
    json : str
        JSON string reflecting comment. It is inserted verbatim after
        ``IS`` in ``COMMENT ON TABLE``, so it presumably has to be quoted
        as an SQL string literal already; confirm with callers.
    schema : str
        The target table's database schema
    table : str
        Database table on which to put the given comment
    """
    qualified_name = "{0}.{1}".format(schema, table)

    execute_sql("COMMENT ON TABLE " + qualified_name + " IS " + json + ";")

    # Read the stored comment back and cast it into JSON; the database
    # raises an error here if the comment is not valid JSON.
    execute_sql(
        "SELECT obj_description('" + qualified_name + "'::regclass)::json"
    )
106
107
108
def execute_sql_script(script, encoding="utf-8-sig"):
    """Execute a SQL script given as a file name.

    The whole file is read into memory and handed to :func:`execute_sql`
    as a single string.

    Parameters
    ----------
    script : str
        Path of the SQL-script
    encoding : str
        Encoding which is used for the SQL file. The default is "utf-8-sig".

    Returns
    -------
    None.
    """
    # ``codecs.open`` is deprecated. ``newline=""`` disables newline
    # translation, so the text read is the same as with ``codecs.open``.
    with open(script, "r", encoding=encoding, newline="") as fd:
        sqlfile = fd.read()

    execute_sql(sqlfile)
127
128
129
@contextmanager
def session_scope():
    """Provide a transactional scope around a series of operations.

    Yields a new session bound to :func:`engine`. The session is committed
    when the ``with`` block finishes normally. If anything is raised inside
    the block (or by the commit itself), the session is rolled back and the
    exception is re-raised. The session is closed in every case.
    """
    session = sessionmaker(bind=engine())()
    try:
        yield session
        session.commit()
    except BaseException:
        session.rollback()
        raise
    finally:
        session.close()
142
143
144
def session_scoped(function):
    """Provide a session scope to a function.

    Can be used as a decorator like this:

    >>> @session_scoped
    ... def get_bind(session):
    ...     return session.get_bind()
    ...
    >>> get_bind()
    Engine(postgresql+psycopg2://egon:***@127.0.0.1:59734/egon-data)

    Note that the decorated function needs to accept a parameter named
    `session`, but is called without supplying a value for that parameter
    because the parameter's value will be filled in by `session_scoped`.
    Using this decorator allows saving an indentation level when defining
    such functions but it also has other usages.
    """

    @functools.wraps(function)
    def with_session(*args, **kwargs):
        # Each call runs inside its own transactional scope.
        with session_scope() as session:
            return function(*args, session=session, **kwargs)

    return with_session
169
170
171
def select_dataframe(sql, index_col=None, warning=True):
    """Select data from local database as pandas.DataFrame

    Parameters
    ----------
    sql : str
        SQL query to be executed.
    index_col : str, optional
        Column(s) to set as index(MultiIndex). The default is None.
    warning : bool, optional
        If exactly ``True``, print a warning containing the query when the
        result is empty (``size == 0``). The default is True.

    Returns
    -------
    df : pandas.DataFrame
        Data returned from SQL statement.
    """
    result = pd.read_sql(sql, engine(), index_col=index_col)

    is_empty = result.size == 0
    if warning is True and is_empty:
        print(f"WARNING: No data returned by statement: \n {sql}")

    return result
194
195
196
def select_geodataframe(
    sql, index_col=None, geom_col="geom", epsg=3035, warning=True
):
    """Select data from local database as geopandas.GeoDataFrame

    Parameters
    ----------
    sql : str
        SQL query to be executed.
    index_col : str, optional
        Column(s) to set as index(MultiIndex). The default is None.
    geom_col : str, optional
        column name to convert to shapely geometries. The default is 'geom'.
    epsg : int, optional
        EPSG code specifying output projection. The default is 3035.
    warning : bool, optional
        If exactly ``True``, print a warning containing the query when the
        result is empty (``size == 0``), as :func:`select_dataframe` does.
        The default is True.

    Returns
    -------
    gdf : geopandas.GeoDataFrame
        Data returned from SQL statement, reprojected to ``epsg``. An empty
        result is returned as read, without reprojection.
    """
    gdf = gpd.read_postgis(
        sql, engine(), index_col=index_col, geom_col=geom_col
    )

    if gdf.size == 0:
        if warning is True:
            print(f"WARNING: No data returned by statement: \n {sql}")
    else:
        gdf = gdf.to_crs(epsg=epsg)

    return gdf
228
229
230
def next_etrago_id(component):
    """Select next id value for components in etrago tables

    Parameters
    ----------
    component : str
        Name of component

    Returns
    -------
    next_id : int
        Next index value: the current maximum id of
        ``grid.egon_etrago_<component>`` plus one, or 1 if the table holds
        no ids yet.

    Notes
    -----
    To catch concurrent DB commits, consider to use
    :func:`check_db_unique_violation` instead.
    """
    # The transformer table names its id column "trafo_id".
    if component == "transformer":
        id_column = "trafo_id"
    else:
        id_column = f"{component}_id"

    max_id = select_dataframe(
        f"""
        SELECT MAX({id_column}) FROM grid.egon_etrago_{component}
        """
    )["max"][0]

    # MAX() yields NULL for an empty table, which pandas may surface as
    # None or NaN. A plain truthiness test would treat NaN as a valid id.
    if pd.isna(max_id):
        return 1

    return max_id + 1
266
267
268
def check_db_unique_violation(func):
    """Wrapper to catch psycopg's UniqueViolation errors during concurrent DB
    commits.

    Preferably used with :func:`next_etrago_id`. If the wrapped function
    fails with a unique violation or a detected deadlock, it is called
    again after waiting 3 seconds, up to 10 retries. After that, the last
    error is raised. Any other error is raised immediately.

    Can be used as a decorator like this:

    >>> @check_db_unique_violation
    ... def commit_something_to_database():
    ...     # commit something here
    ...     return
    ...
    >>> commit_something_to_database()  # doctest: +SKIP

    Examples
    --------
    Add new bus to eTraGo's bus table:

    >>> from egon.data import db
    >>> from egon.data.datasets.etrago_setup import EgonPfHvBus
    ...
    >>> @check_db_unique_violation
    ... def add_etrago_bus():
    ...     bus_id = db.next_etrago_id("bus")
    ...     with db.session_scope() as session:
    ...         session.add(
    ...             EgonPfHvBus(
    ...                 scn_name="eGon2035",
    ...                 bus_id=bus_id,
    ...                 v_nom=1,
    ...                 carrier="whatever",
    ...                 x=52,
    ...                 y=13,
    ...                 geom="<some_geom>"
    ...             )
    ...         )
    ...         session.commit()
    ...
    >>> add_etrago_bus()  # doctest: +SKIP

    Parameters
    ----------
    func : callable
        Function to wrap

    Returns
    -------
    callable
        Wrapper with the same name and docstring as ``func`` that returns
        whatever ``func`` returns.

    Notes
    -----
    Background: using :func:`next_etrago_id` may cause trouble if tasks are
    executed simultaneously, cf.
    https://github.com/openego/eGon-data/issues/514

    Important: your function requires a way to escape the violation on a
    retry, e.g. by fetching a fresh id via :func:`next_etrago_id` inside the
    wrapped function (see example above). Otherwise every retry fails in the
    same way until the retry limit is reached.
    """
    max_retries = 10
    wait_seconds = 3

    @functools.wraps(func)
    def commit(*args, **kwargs):
        retries = 0
        while True:
            try:
                return func(*args, **kwargs)
            except IntegrityError as e:
                if not isinstance(e.orig, UniqueViolation):
                    raise
                print("Entry is not unique, retrying...")
                error = e
            except OperationalError as e:
                # Only deadlocks are retried. Any other operational error
                # (e.g. a lost connection) will not resolve itself and must
                # not be swallowed.
                if not isinstance(e.orig, DeadlockDetected):
                    raise
                print("Deadlock detected, retrying...")
                error = e

            retries += 1
            if retries > max_retries:
                print(f"No success after {max_retries} retries, exiting...")
                raise error
            time.sleep(wait_seconds)

    return commit
360
361
362
def assign_gas_bus_id(dataframe, scn_name, carrier):
    """Assign `bus_id`s to points according to location.

    The points are taken from the given `dataframe` and the geometries by
    which the `bus_id`s are assigned to them are taken from the
    `grid.egon_gas_voronoi` table (filtered by scenario and carrier,
    reprojected to EPSG:4326).

    Parameters
    ----------
    dataframe : pandas.DataFrame
        DataFrame containing points. It is passed to ``gpd.sjoin``, so it
        presumably has to be a geopandas.GeoDataFrame in EPSG:4326; confirm
        with callers.
    scn_name : str
        Name of the scenario
    carrier : str
        Name of the carrier

    Returns
    -------
    res : pandas.DataFrame
        Result of the spatial join, without the ``index_right`` column and
        with an additional ``bus`` column copied from ``bus_id``.
    """
    voronoi_query = f"""
        SELECT bus_id, geom FROM grid.egon_gas_voronoi
        WHERE scn_name = '{scn_name}' AND carrier = '{carrier}';
        """
    voronoi = select_geodataframe(voronoi_query, epsg=4326)

    joined = gpd.sjoin(dataframe, voronoi).drop(columns=["index_right"])
    joined["bus"] = joined["bus_id"]

    # Assert that all power plants have a bus_id.
    # NOTE(review): gpd.sjoin defaults to an inner join, so points outside
    # every voronoi polygon may be dropped instead of tripping this check;
    # confirm whether that is intended.
    assert (
        joined.bus.notnull().all()
    ), f"Some points are not attached to a {carrier} bus."

    return joined
402