Passed
Pull Request — dev (#905)
by
unknown
01:46
created

write_table_to_postgres()   A

Complexity

Conditions 4

Size

Total Lines 56
Code Lines 25

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 25
dl 0
loc 56
rs 9.28
c 0
b 0
f 0
cc 4
nop 5

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
from io import StringIO
2
import csv
3
import time
4
5
from shapely.geometry import Point
6
import geopandas as gpd
7
import numpy as np
8
import pandas as pd
9
10
from egon.data import db, logger
11
12
# Module-level database engine, obtained from `db.engine()` when this module
# is imported. `write_table_to_postgres` uses it to drop/create tables.
engine = db.engine()
13
14
15
def timeit(func):
    """
    Decorator for measuring function's running time.

    After every call the elapsed time is printed to stdout as
    ``Processing time of <qualname>(): <seconds> seconds.``

    Parameters
    ----------
    func: callable
        Function to be timed

    Returns
    -------
    measure_time: callable
        Wrapper which forwards all positional and keyword arguments to
        `func` and returns its result unchanged
    """
    # Local import keeps this block independent of the module's import list
    from functools import wraps

    @wraps(func)  # preserve __name__, __doc__ etc. of the decorated function
    def measure_time(*args, **kw):
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments
        start_time = time.perf_counter()
        result = func(*args, **kw)
        print(
            "Processing time of %s(): %.2f seconds."
            % (func.__qualname__, time.perf_counter() - start_time)
        )
        return result

    return measure_time
30
31
32
def random_point_in_square(geom, tol):
    """
    Draw one random point inside the bounding box of every square.

    Parameters
    ----------
    geom: gpd.Series
        Geometries of square
    tol: float
        tolerance to square bounds; each side of the bounding box is moved
        inwards by `tol / 2` before a point is drawn

    Returns
    -------
    points: gpd.Series
        Series of random points with crs "epsg:3035"
    """
    half_tol = tol / 2
    n_cells = geom.shape[0]
    bounds = geom.bounds

    # shrink the bounding box so that no point ends up on the cell border
    x_low = bounds["minx"] + half_tol
    x_high = bounds["maxx"] - half_tol
    y_low = bounds["miny"] + half_tol
    y_high = bounds["maxy"] - half_tol

    # scale random samples to the shrunken box (x is drawn before y)
    x = (x_high - x_low) * np.random.rand(n_cells) + x_low
    y = (y_high - y_low) * np.random.rand(n_cells) + y_low

    coordinates = [Point(pair) for pair in zip(x, y)]
    return gpd.GeoSeries(pd.Series(coordinates), crs="epsg:3035")
62
63
64
# distribute amenities evenly
65
def specific_int_until_sum(s_sum, i_int):
    """
    Generate list of `i_int` summing to `s_sum`.

    If `s_sum` is not a multiple of `i_int`, the remainder
    (`s_sum % i_int`) is placed at the *first* position of the list,
    followed by `s_sum // i_int` entries of value `i_int`.

    Parameters
    ----------
    s_sum: int
        Total the returned values add up to
    i_int: int
        Value of every full chunk

    Returns
    -------
    list_i: list of int
        E.g. ``specific_int_until_sum(10, 3)`` gives ``[1, 3, 3, 3]``

    Raises
    ------
    ZeroDivisionError
        If `i_int` is zero
    """
    n_full, remainder = divmod(s_sum, i_int)
    # the remainder chunk is only added if there is one
    list_i = [remainder] if remainder else []
    list_i += n_full * [i_int]
    return list_i
72
73
74
def random_ints_until_sum(s_sum, m_max):
    """
    Generate random integers in [1, `m_max`] summing to `s_sum`.

    Values are drawn one at a time with `np.random.randint`. A draw which
    would reach or exceed the amount still missing is replaced by that
    amount, so the last value closes the gap to `s_sum` exactly.

    Parameters
    ----------
    s_sum: int
        Total the returned values add up to. For `s_sum` <= 0 an empty list
        is returned.
    m_max: int
        Largest value of a single draw (inclusive)

    Returns
    -------
    list_r: list of int
        Random integers summing to `s_sum`
    """
    list_r = []
    remaining = s_sum
    while remaining > 0:
        # upper bound of randint is exclusive, hence `m_max + 1`
        r = np.random.randint(1, m_max + 1)
        if r >= remaining:
            # never overshoot the requested total
            r = remaining
        list_r.append(r)
        remaining -= r
    return list_r
85
86
87
def write_table_to_postgis(gdf, table, engine=db.engine(), drop=True):
    """
    Helper function to append df data to table in db. Only predefined columns
    are passed. Error will raise if column is missing. Dtype of columns are
    taken from table definition.

    Parameters
    ----------
    gdf: gpd.DataFrame
        Table of data
    table: declarative_base
        Metadata of db table to export to
    engine:
        connection to database db.engine(). Note that the default value is
        evaluated once, when this function is defined (i.e. at import time).
    drop: bool
        Drop table before appending

    Raises
    ------
    KeyError
        Raised by pandas if a column of the db table is missing in `gdf`
    """
    sql_table = table.__table__

    # Only take in db table defined columns, dtypes come from the table
    dtypes = {column.key: column.type for column in sql_table.columns}
    gdf = gdf.loc[:, list(dtypes)]

    if drop:
        sql_table.drop(bind=engine, checkfirst=True)
        sql_table.create(bind=engine)

    # Name and schema are read from the table object itself instead of
    # `__tablename__` / `__table_args__["schema"]`, so tables declared with
    # tuple-style or without `__table_args__` are handled as well.
    gdf.to_postgis(
        name=sql_table.name,
        con=engine,
        if_exists="append",
        schema=sql_table.schema,
        dtype=dtypes,
    )
127
128
129
def psql_insert_copy(table, conn, keys, data_iter):
    """
    Insert rows with a single ``COPY ... FROM STDIN WITH CSV`` statement.

    The signature matches the callable accepted by the `method` argument of
    `pandas.DataFrame.to_sql`.

    Parameters
    ----------
    table : pandas.io.sql.SQLTable
    conn : sqlalchemy.engine.Engine or sqlalchemy.engine.Connection
    keys : list of str
        Column names
    data_iter : Iterable that iterates the values to be inserted
    """
    # the raw DBAPI connection is required to obtain a cursor
    raw_connection = conn.connection
    with raw_connection.cursor() as cursor:
        # serialise all rows as CSV into an in-memory buffer
        csv_buffer = StringIO()
        csv.writer(csv_buffer).writerows(data_iter)
        csv_buffer.seek(0)

        quoted_columns = ", ".join(f'"{key}"' for key in keys)
        # prefix the table name with its schema if one is set
        target = (
            f"{table.schema}.{table.name}" if table.schema else table.name
        )

        cursor.copy_expert(
            sql=f"COPY {target} ({quoted_columns}) FROM STDIN WITH CSV",
            file=csv_buffer,
        )
157
158
159
def _select_table_columns(df, column_names):
    """
    Restrict `df` to the columns listed in `column_names`.

    Columns of `df` which are not listed are dropped. Listed columns which
    are absent from `df` are reported with a warning and skipped.
    """
    missing = [name for name in column_names if name not in df.columns]
    if missing:
        logger.warning(f"Columns: {missing} missing!")
    present = [name for name in column_names if name in df.columns]
    return df.loc[:, present]


def write_table_to_postgres(
    df, db_table, drop=False, index=False, if_exists="append"
):
    """
    Helper function to append df data to table in db. Fast string-copy is used.
    Only predefined columns are passed. If column is missing in dataframe a
    warning is logged. Dtypes of columns are taken from table definition. The
    writing process happens in a scoped session.

    Parameters
    ----------
    df: pd.DataFrame
        Table of data
    db_table: declarative_base
        Metadata of db table to export to
    drop: boolean, default False
        Drop db-table before appending. If False, the table is created in
        case it does not exist yet.
    index: boolean, default False
        Write DataFrame index as a column.
    if_exists: {'fail', 'replace', 'append'}, default 'append'
        Passed on to `pandas.DataFrame.to_sql`. The table is always created
        before the data is written, so it is expected to exist at that point.

        - fail: If table exists, pandas raises a ValueError.
        - replace: If table exists, drop it, recreate it, and insert data.
        - append: If table exists, insert data. Create if does not exist.

    """
    logger.info("Write table to db")
    sql_table = db_table.__table__

    # Column names and dtypes as defined in the db table
    columns = {column.key: column.type for column in sql_table.columns}

    # Take only the columns defined in class, warn about absent ones
    df = _select_table_columns(df, list(columns))

    if drop:
        sql_table.drop(bind=engine, checkfirst=True)
        sql_table.create(bind=engine)
    else:
        sql_table.create(bind=engine, checkfirst=True)

    with db.session_scope() as session:
        df.to_sql(
            name=sql_table.name,
            schema=sql_table.schema,
            con=session.connection(),
            if_exists=if_exists,
            index=index,
            method=psql_insert_copy,
            dtype=columns,
        )
216