Passed
Push — dev ( d02171...8ae202 )
by
unknown
02:10 queued 12s
created

data.datasets.Dataset.update()   A

Complexity

Conditions 2

Size

Total Lines 23
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 18
nop 2
dl 0
loc 23
rs 9.5
c 0
b 0
f 0
1
"""The API for configuring datasets."""
2
3
from __future__ import annotations
4
5
from collections import abc
6
from dataclasses import dataclass
7
from functools import reduce
8
from typing import Callable, Iterable, Set, Tuple, Union
9
import re
10
11
from airflow.operators import BaseOperator as Operator
12
from airflow.operators.python_operator import PythonOperator
13
from sqlalchemy import Column, ForeignKey, Integer, String, Table, orm, tuple_
14
from sqlalchemy.ext.declarative import declarative_base
15
16
from egon.data import db, logger
17
18
Base = declarative_base()
19
SCHEMA = "metadata"
20
21
22
def setup():
    """Create the schema and tables that hold dataset bookkeeping.

    The schema named by ``SCHEMA`` is created first (if missing), followed
    by the ``datasets`` table and the ``dependency_graph`` association
    table. Tables that already exist are left alone (``checkfirst=True``).
    """
    # TODO: Move this into a task generating the initial database structure.
    db.execute_sql(f"CREATE SCHEMA IF NOT EXISTS {SCHEMA};")
    # `Model`'s table has to exist before the association table, because
    # the latter's foreign keys point at it.
    for table in (Model.__table__, DependencyGraph):
        table.create(bind=db.engine(), checkfirst=True)
28
29
30
# TODO: Figure out how to use a mapped class as an association table.
31
#
32
# Trying it out, I ran into quite a few problems and didn't have time to do
33
# further research. The benefits are mostly just convenience, so it doesn't
34
# have a high priority. But I'd like to keep the code I started out with to
35
# have a starting point for me or anybody else trying again.
36
#
37
# class DependencyGraph(Base):
38
#     __tablename__ = "dependency_graph"
39
#     __table_args__ = {"schema": SCHEMA}
40
#     dependency_id = Column(Integer, ForeignKey(Model.id), primary_key=True,)
41
#     dependent_id = Column(Integer, ForeignKey(Model.id), primary_key=True,)
42
43
# Association table of the many-to-many "depends on" relation between
# rows of the `datasets` table. Both columns are built the same way: an
# integer foreign key to `datasets.id` that is part of the composite
# primary key.
DependencyGraph = Table(
    "dependency_graph",
    Base.metadata,
    *(
        Column(
            column_name,
            Integer,
            ForeignKey(f"{SCHEMA}.datasets.id"),
            primary_key=True,
        )
        for column_name in ("dependency_id", "dependent_id")
    ),
    schema=SCHEMA,
)
60
61
62
class Model(Base):
    """Database record of a dataset: its name, version and dependencies.

    Rows live in the ``datasets`` table inside the ``SCHEMA`` schema.
    Dependencies between rows are stored in the ``dependency_graph``
    association table.
    """

    __tablename__ = "datasets"
    __table_args__ = {"schema": SCHEMA}

    id = Column(Integer, primary_key=True)
    # Names are unique, so at most one row per dataset name can exist.
    name = Column(String, unique=True, nullable=False)
    version = Column(String, nullable=False)
    epoch = Column(Integer, default=0)

    # Self-referential many-to-many relationship: `dependencies` lists the
    # rows this row depends on, while the `dependents` backref lists the
    # rows that depend on this one.
    dependencies = orm.relationship(
        "Model",
        secondary=DependencyGraph,
        primaryjoin=(id == DependencyGraph.c.dependent_id),
        secondaryjoin=(id == DependencyGraph.c.dependency_id),
        backref=orm.backref("dependents", cascade="all, delete"),
    )
76
77
78
#: A :class:`Task` is either an Airflow :class:`Operator` or a
#: :class:`Callable <typing.Callable>` which takes no arguments and
#: returns :obj:`None`. A :class:`Callable <typing.Callable>` is turned
#: into an :class:`Operator` by wrapping it in a :class:`PythonOperator`
#: whose :obj:`~PythonOperator.task_id` is the
#: :class:`Callable <typing.Callable>`'s :obj:`~definition.__name__`
#: with underscores replaced by hyphens.
#: If the :class:`Callable <typing.Callable>`'s `__module__`__ attribute
#: contains the string :obj:`"egon.data.datasets."`, the
#: :obj:`~PythonOperator.task_id` is additionally prefixed with the
#: module name, stripped of :obj:`"egon.data.datasets."` and followed by
#: a dot.
#:
#: __ https://docs.python.org/3/reference/datamodel.html#index-34
Task = Union[Callable[[], None], Operator]
#: The simplest graph of tasks is a single node, i.e. one :class:`Task`.
#: More complex graphs are built by nesting
#: :class:`sets <builtins.set>` and :class:`tuples <builtins.tuple>` of
#: :class:`TaskGraphs <TaskGraph>`. The members of a :class:`set` are
#: unordered and can be executed in parallel. The members of a
#: :class:`tuple` are implicitly ordered and will be executed
#: sequentially, in the order given.
TaskGraph = Union[Task, Set["TaskGraph"], Tuple["TaskGraph", ...]]
101
102
103
def prefix(o):
    """Return the task id prefix derived from the module defining `o`.

    If `o.__module__` contains this module's name followed by a dot, the
    result is `o.__module__` with that string removed and a trailing dot
    appended. Otherwise the result is the empty string.
    """
    package = f"{__name__}."
    origin = o.__module__
    if package not in origin:
        return ""
    return f"{origin.replace(package, '')}."
107
108
109
@dataclass
class Tasks(dict):
    """A parsed :class:`TaskGraph`.

    The object itself is a mapping from task ids to the operators found
    in the graph. Because ``__init__`` is defined explicitly, the
    dataclass decorator does not generate one.
    """

    #: The entry tasks of the graph, i.e. the ones other tasks have to be
    #: linked to in order to run before the graph.
    first: Set[Task]
    #: The exit tasks of the graph, i.e. the ones which have to be linked
    #: to other tasks in order to run after the graph.
    last: Set[Task]
    #: The parsed graph, with every callable replaced by the operator
    #: wrapping it.
    graph: TaskGraph = ()

    def __init__(self, graph: TaskGraph):
        """Connect multiple tasks into a potentially complex graph.

        Parses a :class:`TaskGraph` into a :class:`Tasks` object.

        A callable is wrapped in a :class:`PythonOperator`, the members of
        a set are parsed independently of each other, and the members of
        a tuple are parsed and then chained so that every last task of one
        member is upstream of every first task of the next member.

        Raises a :class:`TypeError` if `graph` is neither a task, nor a
        set, nor a tuple.
        """
        super().__init__()
        if callable(graph):
            graph = PythonOperator(
                task_id=f"{prefix(graph)}{graph.__name__.replace('_', '-')}",
                python_callable=graph,
            )
        self.graph = graph
        if isinstance(graph, Operator):
            self.first = {graph}
            self.last = {graph}
            self[graph.task_id] = graph
        elif isinstance(graph, abc.Sized) and len(graph) == 0:
            # `{}` would be an empty `dict`, not an empty set.
            self.first = set()
            self.last = set()
        elif isinstance(graph, abc.Set):
            results = [Tasks(subtasks) for subtasks in graph]
            self.first = {task for result in results for task in result.first}
            self.last = {task for result in results for task in result.last}
            for result in results:
                self.update(result)
            self.graph = {tasks.graph for tasks in results}
        elif isinstance(graph, tuple):
            results = [Tasks(subtasks) for subtasks in graph]
            # Chain each member of the tuple to its successor.
            for left, right in zip(results[:-1], results[1:]):
                for last in left.last:
                    for first in right.first:
                        last.set_downstream(first)
            self.first = results[0].first
            self.last = results[-1].last
            for result in results:
                self.update(result)
            self.graph = tuple(tasks.graph for tasks in results)
        else:
            raise TypeError(
                "`egon.data.datasets.Tasks` got an argument of type:\n\n"
                f"  {type(graph)}\n\n"
                "where only `Task`s, `Set`s and `Tuple`s are allowed."
            )
157
158
159
@dataclass
class Dataset:
    """A named and versioned group of tasks.

    On initialization, the tasks are parsed into a :class:`Tasks` object,
    every task's ``execute`` method is replaced by a version-checking
    wrapper (see :meth:`check_version`) and the first tasks are linked
    downstream of the dependencies.
    """

    #: The name of the Dataset
    name: str
    #: The :class:`Dataset`'s version. Can be anything from a simple
    #: semantic versioning string like "2.1.3", to a more complex
    #: string, like for example "2021-01-01.schleswig-holstein.0" for
    #: OpenStreetMap data.
    #: Note that the latter encodes the :class:`Dataset`'s date, region
    #: and a sequential number in case the data changes without the date
    #: or region changing, for example due to implementation changes.
    version: str
    #: The first task(s) of this :class:`Dataset` will be marked as
    #: downstream of any of the listed dependencies. In case of bare
    #: :class:`Task`, a direct link will be created whereas for a
    #: :class:`Dataset` the link will be made to all of its last tasks.
    dependencies: Iterable[Union[Dataset, Task]] = ()
    #: The tasks of this :class:`Dataset`. A :class:`TaskGraph` will
    #: automatically be converted to :class:`Tasks`.
    tasks: Union[Tasks, TaskGraph] = ()

    def check_version(self, after_execution=()):
        """Return a replacement for a task's ``execute`` method.

        The replacement looks up the stored records carrying this
        dataset's name. If one of them has this dataset's version and the
        version does not end in ".dev", the task is skipped and
        :obj:`None` is returned. Otherwise the stored records are deleted,
        the task's original ``execute`` is run, every function in
        `after_execution` is called with the database session, and the
        result of ``execute`` is returned.
        """

        def skip_task(task, *xs, **ks):
            with db.session_scope() as session:
                datasets = session.query(Model).filter_by(name=self.name).all()
                executed = self.version in [ds.version for ds in datasets]
                if executed and not re.search(r"\.dev$", self.version):
                    logger.info(
                        f"Dataset '{self.name}' version '{self.version}'"
                        f" already executed. Skipping."
                    )
                    return None
                for ds in datasets:
                    session.delete(ds)
                # `type(task)` is the "versioned" subclass created in
                # `__post_init__`, so `super` resolves to the `execute`
                # of the task's original class.
                result = super(type(task), task).execute(*xs, **ks)
                for function in after_execution:
                    function(session)
                return result

        return skip_task

    def update(self, session):
        """Add a record of this dataset's name and version to `session`.

        The new record is linked to the stored records of those
        dependencies which are :class:`Datasets <Dataset>` or which carry
        a ``dataset`` attribute, matched by name and version.
        """
        # A `Dataset` stands for itself. Anything else reaching this point
        # carries a `dataset` attribute, which is set on every task in
        # `__post_init__`.
        required = [
            dependency
            if isinstance(dependency, Dataset)
            else dependency.dataset
            for dependency in self.dependencies
            if isinstance(dependency, Dataset)
            or hasattr(dependency, "dataset")
        ]
        dataset = Model(name=self.name, version=self.version)
        dataset.dependencies = (
            session.query(Model)
            .filter(
                tuple_(Model.name, Model.version).in_(
                    [(ds.name, ds.version) for ds in required]
                )
            )
            .all()
        )
        session.add(dataset)

    def __post_init__(self):
        """Parse the tasks, add version checking and link dependencies."""
        self.dependencies = list(self.dependencies)
        if not isinstance(self.tasks, Tasks):
            self.tasks = Tasks(self.tasks)
        if len(self.tasks.last) > 1:
            # Explicitly create single final task, because we can't know
            # which of the multiple tasks finishes last.
            name = prefix(self)
            name = f"{name if name else f'{self.__module__}.'}{self.name}."
            update_version = PythonOperator(
                task_id=f"{name}update-version",
                # Do nothing, because updating will be added later.
                python_callable=lambda *xs, **ks: None,
            )
            self.tasks = Tasks((self.tasks.graph, update_version))
        # Due to the `if`-block above, there is now at most one task in
        # `self.tasks.last`. There is none if the dataset has no tasks at
        # all, which is the default, so don't index into an empty list.
        last = next(iter(self.tasks.last), None)
        for task in self.tasks.values():
            task.dataset = self
            cls = task.__class__
            # Swap the task's class for a subclass whose `execute` checks
            # the version first. Only the last task records the new
            # version after it has been executed.
            versioned = type(
                f"{self.name[:1].upper()}{self.name[1:]} (versioned)",
                (cls,),
                {
                    "execute": self.check_version(
                        after_execution=[self.update] if task is last else []
                    )
                },
            )
            task.__class__ = versioned

        # NOTE(review): dependencies which are bare callables are neither
        # `Dataset`s nor `Operator`s, so no link is created for them here.
        predecessors = [
            task
            for dataset in self.dependencies
            if isinstance(dataset, Dataset)
            for task in dataset.tasks.last
        ] + [task for task in self.dependencies if isinstance(task, Operator)]
        for p in predecessors:
            for first in self.tasks.first:
                p.set_downstream(first)
267