Passed
Pull Request — dev (#840)
by
unknown
01:32
created

data.db.check_db_unique_violation()   C

Complexity

Conditions 9

Size

Total Lines 92
Code Lines 28

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 28
dl 0
loc 92
rs 6.6666
c 0
b 0
f 0
cc 9
nop 1

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

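Commonly applied refactorings include: Extract Method; and, if many parameters or temporary variables are present, Replace Method with Method Object.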

1
from contextlib import contextmanager
2
import codecs
3
import functools
4
import time
5
6
from psycopg2.errors import DeadlockDetected, UniqueViolation
7
from sqlalchemy import create_engine, text
8
from sqlalchemy.exc import IntegrityError, OperationalError
9
from sqlalchemy.orm import sessionmaker
10
import geopandas as gpd
11
import pandas as pd
12
13
from egon.data import config
14
15
16
def credentials():
    """Return local database connection parameters.

    The ``egon-data`` section of the configuration is returned. Every
    ``--database-*`` option found in it is additionally made available
    under the key :func:`engine` reads (``POSTGRES_DB``,
    ``POSTGRES_PASSWORD``, ``HOST``, ``PORT``, ``POSTGRES_USER``); the
    original option keys are kept as well.

    Note that the mapping obtained from ``config.settings()`` is updated in
    place and returned; if that object is cached by ``config``, the added
    keys persist there too (NOTE(review): confirm whether that is intended).

    Returns
    -------
    dict
        Complete DB connection information
    """
    option_to_key = {
        "--database-name": "POSTGRES_DB",
        "--database-password": "POSTGRES_PASSWORD",
        "--database-host": "HOST",
        "--database-port": "PORT",
        "--database-user": "POSTGRES_USER",
    }
    settings = config.settings()["egon-data"]

    # Walk the settings (not the translation table) so the new keys are
    # inserted in the same order the options appear in the configuration.
    aliases = {}
    for option in settings:
        if option in option_to_key:
            aliases[option_to_key[option]] = settings[option]

    settings.update(aliases)
    return settings
39
40
41
def engine():
    """Engine for local database.

    The connection parameters come from :func:`credentials`. User name and
    password are percent-encoded before they are put into the connection
    URL, so values containing characters with a special meaning in URLs
    (``@``, ``:``, ``/``, ``%``, spaces, ...) no longer break URL parsing.
    SQLAlchemy decodes them again when it parses the URL.

    Returns
    -------
    sqlalchemy.engine.Engine
        Engine connected via ``postgresql+psycopg2``.
    """
    from urllib.parse import quote

    db_config = credentials()
    # ``str()`` keeps non-string values (e.g. a numeric password read from
    # the configuration) working, as the previous f-string interpolation did.
    # ``safe=""`` also encodes "/" so nothing is left that the URL parser
    # could mistake for a delimiter.
    user = quote(str(db_config["POSTGRES_USER"]), safe="")
    password = quote(str(db_config["POSTGRES_PASSWORD"]), safe="")
    return create_engine(
        f"postgresql+psycopg2://{user}:{password}@{db_config['HOST']}:"
        f"{db_config['PORT']}/{db_config['POSTGRES_DB']}",
        echo=False,
    )
50
51
52
def execute_sql(sql_string):
    """Execute a SQL expression given as string.

    The SQL expression passed as plain string is converted to a
    `sqlalchemy.sql.expression.TextClause` and run on a connection of a
    freshly created :func:`engine` with the ``autocommit`` execution option
    set. The connection is closed when the statement has been executed.

    Parameters
    ----------
    sql_string : str
        SQL expression

    """
    db_engine = engine()
    autocommit_connection = db_engine.connect().execution_options(
        autocommit=True
    )

    with autocommit_connection as connection:
        statement = text(sql_string)
        connection.execute(statement)
68
69
70
def submit_comment(json, schema, table):
    """Add comment to table.

    We use `Open Energy Metadata <https://github.com/OpenEnergyPlatform/
    oemetadata/blob/develop/metadata/v141/metadata_key_description.md>`_
    standard for describing our data. Metadata is stored as JSON in the table
    comment.

    After the comment has been set, it is read back and cast to ``json`` in
    the database, so invalid JSON makes the second statement fail.

    Parameters
    ----------
    json : str
        JSON string reflecting comment. It is inserted verbatim after
        ``COMMENT ON TABLE ... IS``, so it must already be a valid SQL string
        literal (the JSON text enclosed in single quotes, embedded single
        quotes doubled).
    schema : str
        The target table's database schema
    table : str
        Database table on which to put the given comment
    """
    prefix_str = "COMMENT ON TABLE {0}.{1} IS ".format(schema, table)

    # Use the two-argument form of obj_description(): the one-argument form
    # is deprecated in PostgreSQL because OIDs are not guaranteed to be unique
    # across system catalogs, so it may return the wrong comment. Tables live
    # in pg_class.
    check_json_str = (
        "SELECT obj_description('{0}.{1}'::regclass, 'pg_class')::json".format(
            schema, table
        )
    )

    execute_sql(prefix_str + json + ";")

    # Query table comment and cast it into JSON
    # The query throws an error if JSON is invalid
    execute_sql(check_json_str)
100
101
102
def execute_sql_script(script, encoding="utf-8-sig"):
    """Execute a SQL script given as a file name.

    The whole file is read into memory and passed to :func:`execute_sql` as
    a single string.

    Parameters
    ----------
    script : str
        Path of the SQL-script
    encoding : str
        Encoding which is used for the SQL file. The default is "utf-8-sig",
        which also strips a leading byte order mark if present.

    Returns
    -------
    None.

    """
    # Builtin open() replaces codecs.open(), which is deprecated since
    # Python 3.14. newline="" disables newline translation, just like
    # codecs.open() (which always reads in binary mode), so the script text
    # reaches the database unchanged.
    with open(script, "r", encoding=encoding, newline="") as fd:
        sqlfile = fd.read()

    execute_sql(sqlfile)
121
122
123
@contextmanager
def session_scope():
    """Provide a transactional scope around a series of operations.

    A new session bound to a fresh :func:`engine` is yielded. When the
    ``with`` body finishes normally the session is committed. If the body
    or the commit raises, the session is rolled back and the exception is
    re-raised. The session is closed in every case.

    Yields
    ------
    sqlalchemy.orm.Session
        Session on the local database.
    """
    session = sessionmaker(bind=engine())()
    try:
        yield session
        session.commit()
    except BaseException:
        # Deliberately as broad as a bare ``except:`` (including e.g.
        # KeyboardInterrupt): always roll back, then propagate.
        session.rollback()
        raise
    finally:
        session.close()
136
137
138
def session_scoped(function):
    """Provide a session scope to a function.

    Can be used as a decorator like this:

    >>> @session_scoped
    ... def get_bind(session):
    ...     return session.get_bind()
    ...
    >>> get_bind()
    Engine(postgresql+psycopg2://egon:***@127.0.0.1:59734/egon-data)

    Note that the decorated function needs to accept a parameter named
    `session`, but is called without supplying a value for that parameter
    because the parameter's value will be filled in by `session_scoped`.
    Each call runs inside its own :func:`session_scope`, so the session is
    committed after the function returns and rolled back if it raises.
    Using this decorator allows saving an indentation level when defining
    such functions but it also has other usages.
    """

    @functools.wraps(function)
    def with_session(*args, **kwargs):
        with session_scope() as session:
            result = function(*args, session=session, **kwargs)
        # The scope has been left (and committed) at this point.
        return result

    return with_session
163
164
165
def select_dataframe(sql, index_col=None, warning=True):
    """Select data from local database as pandas.DataFrame

    Parameters
    ----------
    sql : str
        SQL query to be executed.
    index_col : str, optional
        Column(s) to set as index(MultiIndex). The default is None.
    warning : bool, optional
        Print a warning if the query returns no data. Only the value
        ``True`` itself enables the warning. The default is True.

    Returns
    -------
    df : pandas.DataFrame
        Data returned from SQL statement.

    """
    result = pd.read_sql(sql, engine(), index_col=index_col)

    if warning is True and result.size == 0:
        print(f"WARNING: No data returned by statement: \n {sql}")

    return result
188
189
190
def select_geodataframe(
    sql, index_col=None, geom_col="geom", epsg=3035, warning=True
):
    """Select data from local database as geopandas.GeoDataFrame

    Parameters
    ----------
    sql : str
        SQL query to be executed.
    index_col : str, optional
        Column(s) to set as index(MultiIndex). The default is None.
    geom_col : str, optional
        column name to convert to shapely geometries. The default is 'geom'.
    epsg : int, optional
        EPSG code specifying output projection. The default is 3035.
        Only applied when the query returned data.
    warning : bool, optional
        Print a warning if the query returns no data, analogous to
        :func:`select_dataframe`. The default is True.

    Returns
    -------
    gdf : geopandas.GeoDataFrame
        Data returned from SQL statement, reprojected to ``epsg`` unless
        it is empty.

    """
    gdf = gpd.read_postgis(
        sql, engine(), index_col=index_col, geom_col=geom_col
    )

    if gdf.size == 0:
        if warning:
            print(f"WARNING: No data returned by statement: \n {sql}")
    else:
        # Only reproject non-empty results; an empty result is returned
        # unchanged, as before.
        gdf = gdf.to_crs(epsg=epsg)

    return gdf
222
223
224
def next_etrago_id(component):
    """Select next id value for components in etrago tables

    The next id is one more than the largest id currently stored in
    ``grid.egon_etrago_<component>``, or 1 if there is no id yet.

    Parameters
    ----------
    component : str
        Name of component. The id column is ``<component>_id``, except for
        ``"transformer"`` whose id column is ``trafo_id``.

    Returns
    -------
    next_id : int
        Next index value

    Notes
    -----
    To catch concurrent DB commits, consider using
    :func:`check_db_unique_violation`.
    """
    if component == "transformer":
        id_column = "trafo_id"
    else:
        id_column = f"{component}_id"

    max_id = select_dataframe(
        f"""
        SELECT MAX({id_column}) FROM grid.egon_etrago_{component}
        """
    )["max"][0]

    # MAX() yields NULL for an empty table, which pandas may surface as None
    # or NaN. NaN is truthy, so a plain ``if max_id`` check would return
    # NaN + 1 instead of 1.
    if pd.isna(max_id):
        return 1

    # Return a plain Python int rather than a numpy scalar; psycopg2 does
    # not adapt numpy integer types by default.
    return int(max_id) + 1
260
261
262
def _retry_message(error):
    """Return the log message for a retryable DB error, or None if fatal.

    Retryable are unique-key violations (psycopg2 ``UniqueViolation``
    wrapped in SQLAlchemy's ``IntegrityError``) and deadlocks (psycopg2
    ``DeadlockDetected`` wrapped in SQLAlchemy's ``OperationalError``).
    """
    if isinstance(error, IntegrityError) and isinstance(
        error.orig, UniqueViolation
    ):
        return "Entry is not unique, retrying..."
    if isinstance(error, OperationalError) and isinstance(
        error.orig, DeadlockDetected
    ):
        return "Deadlock detected, retrying..."
    return None


def check_db_unique_violation(func, *, max_retries=10, delay=3):
    """Wrapper to catch psycopg's UniqueViolation errors during concurrent DB
    commits.

    Unique violations and deadlocks raised by the wrapped function are
    caught and the function is called again, waiting ``delay`` seconds
    before each retry. After ``max_retries`` unsuccessful retries the
    original exception is raised. Any other ``IntegrityError`` or
    ``OperationalError`` (and any other exception) is raised immediately.

    Preferably used with :func:`next_etrago_id`.

    Parameters
    ----------
    func : callable
        Function to wrap
    max_retries : int, optional
        Number of retries after the first failed attempt before giving up.
        The default is 10.
    delay : float, optional
        Seconds to wait before each retry. The default is 3.

    Returns
    -------
    callable
        Wrapper with the same signature, name and docstring as ``func``
        that returns whatever ``func`` returns.

    Notes
    -----
    Background: using :func:`next_etrago_id` may cause trouble if tasks are
    executed simultaneously, cf.
    https://github.com/openego/eGon-data/issues/514

    Important: your function requires a way to escape the violation on a
    retry, otherwise every retry fails the same way until the retries are
    exhausted. In case of eTraGo tables you can determine the id with
    :func:`next_etrago_id` inside the decorated function, see example below.

    Examples
    --------
    Can be used as a decorator like this:

    >>> @check_db_unique_violation
    ... def commit_something_to_database():
    ...     # commit something here
    ...     return
    ...
    >>> commit_something_to_database()  # doctest: +SKIP

    Add new bus to eTraGo's bus table. The id is determined inside the
    decorated function, so every retry uses a fresh id; ``session_scope``
    commits when the ``with`` block is left:

    >>> from egon.data import db
    >>> from egon.data.datasets.etrago_setup import EgonPfHvBus
    >>> @check_db_unique_violation
    ... def add_etrago_bus():
    ...     bus_id = db.next_etrago_id("bus")
    ...     with db.session_scope() as session:
    ...         session.add(
    ...             EgonPfHvBus(
    ...                 scn_name="eGon2035",
    ...                 bus_id=bus_id,
    ...                 v_nom=1,
    ...                 carrier="whatever",
    ...                 x=52,
    ...                 y=13,
    ...                 geom="<some_geom>"
    ...             )
    ...         )
    ...
    >>> add_etrago_bus()  # doctest: +SKIP
    """

    @functools.wraps(func)
    def commit(*args, **kwargs):
        retries = 0
        while True:
            try:
                return func(*args, **kwargs)
            except (IntegrityError, OperationalError) as error:
                message = _retry_message(error)
                if message is None:
                    # Not caused by concurrent access; retrying would not
                    # help (and previously looped forever for other
                    # OperationalErrors), so propagate it.
                    raise
                retries += 1
                if retries > max_retries:
                    print(
                        f"No success after {max_retries} retries, exiting..."
                    )
                    raise
                print(message)
                time.sleep(delay)

    return commit
354
355
356
def assign_gas_bus_id(dataframe, scn_name, carrier):
    """Assigns bus_ids to points (contained in a dataframe) according to location

    Each point is spatially joined with the gas Voronoi polygons
    (``grid.egon_gas_voronoi``) of the given scenario and carrier and
    receives the ``bus_id`` of the polygon it intersects.

    Parameters
    ----------
    dataframe : geopandas.GeoDataFrame
        GeoDataFrame containing points. Presumably in EPSG:4326, the CRS
        the Voronoi polygons are loaded in here (TODO confirm with callers).
    scn_name : str
        Name of the scenario
    carrier : str
        Name of the carrier

    Returns
    -------
    res : geopandas.GeoDataFrame
        Input rows with the polygon's ``bus_id`` added, plus a copy of it in
        column ``bus``. A point intersecting several polygons appears once
        per polygon.

    Raises
    ------
    AssertionError
        If any point does not lie within one of the Voronoi polygons.
    """

    voronoi = select_geodataframe(
        f"""
        SELECT bus_id, geom FROM grid.egon_gas_voronoi
        WHERE scn_name = '{scn_name}' AND carrier = '{carrier}';
        """,
        epsg=4326,
    )

    # A left join keeps points that fall outside every Voronoi polygon (with
    # a missing bus_id), so the check below can detect them. The default
    # inner join silently dropped such points, which made the check vacuous.
    res = gpd.sjoin(dataframe, voronoi, how="left")
    res["bus"] = res["bus_id"]
    res = res.drop(columns=["index_right"])

    # Assert that all points have a bus_id
    assert (
        res.bus.notnull().all()
    ), f"Some points are not attached to a {carrier} bus."

    return res
392