Completed: Push to dev (783b7a...fb1b84) by Stephan (21s, queued 12s)

data.datasets.Dataset.check_version(): grade B

Complexity
    Conditions    6

Size
    Total Lines   20
    Code Lines    16

Duplication
    Lines         0
    Ratio         0 %

Importance
    Changes       0

Metric    Value
cc        6
eloc      16
nop       2
dl        0
loc       20
rs        8.6666
c         0
b         0
f         0
"""The API for configuring datasets."""

from __future__ import annotations

from collections import abc
from dataclasses import dataclass
from functools import reduce
from typing import Callable, Iterable, Set, Tuple, Union
import re

from airflow import DAG
from airflow.operators import BaseOperator as Operator
from airflow.operators.python_operator import PythonOperator
from sqlalchemy import Column, ForeignKey, Integer, String, Table, orm, tuple_
from sqlalchemy.ext.declarative import declarative_base

from egon.data import db, logger


Base = declarative_base()
SCHEMA = "metadata"


def setup():
    """Create the database structure for storing dataset information."""
    # TODO: Move this into a task generating the initial database structure.
    db.execute_sql(f"CREATE SCHEMA IF NOT EXISTS {SCHEMA};")
    Model.__table__.create(bind=db.engine(), checkfirst=True)
    DependencyGraph.create(bind=db.engine(), checkfirst=True)


# TODO: Figure out how to use a mapped class as an association table.
#
# Trying it out, I ran into quite a few problems and didn't have time to do
# further research. The benefits are mostly just convenience, so it doesn't
# have a high priority. But I'd like to keep the code I started out with to
# have a starting point for me or anybody else trying again.
#
# class DependencyGraph(Base):
#     __tablename__ = "dependency_graph"
#     __table_args__ = {"schema": SCHEMA}
#     dependency_id = Column(Integer, ForeignKey(Model.id), primary_key=True,)
#     dependent_id = Column(Integer, ForeignKey(Model.id), primary_key=True,)

DependencyGraph = Table(
    "dependency_graph",
    Base.metadata,
    Column(
        "dependency_id",
        Integer,
        ForeignKey(f"{SCHEMA}.datasets.id"),
        primary_key=True,
    ),
    Column(
        "dependent_id",
        Integer,
        ForeignKey(f"{SCHEMA}.datasets.id"),
        primary_key=True,
    ),
    schema=SCHEMA,
)


class Model(Base):
    __tablename__ = "datasets"
    __table_args__ = {"schema": SCHEMA}
    id = Column(Integer, primary_key=True)
    name = Column(String, unique=True, nullable=False)
    version = Column(String, nullable=False)
    epoch = Column(Integer, default=0)
    dependencies = orm.relationship(
        "Model",
        secondary=DependencyGraph,
        primaryjoin=id == DependencyGraph.c.dependent_id,
        secondaryjoin=id == DependencyGraph.c.dependency_id,
        backref=orm.backref("dependents", cascade="all, delete"),
    )
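
# Illustration (not part of the original module): a row in the
# ``dependency_graph`` table with ``dependency_id=1`` and ``dependent_id=2``
# means that dataset 2 depends on dataset 1, so the ``Model`` with id 2
# lists the ``Model`` with id 1 in its ``dependencies``, and the ``Model``
# with id 1 lists the one with id 2 in its ``dependents``.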


#: A :class:`Task` is an Airflow :class:`Operator` or any
#: :class:`Callable <typing.Callable>` taking no arguments and returning
#: :obj:`None`. :class:`Callables <typing.Callable>` will be converted
#: to :class:`Operators <Operator>` by wrapping them in a
#: :class:`PythonOperator` and setting the :obj:`~PythonOperator.task_id`
#: to the :class:`Callable <typing.Callable>`'s
#: :obj:`~definition.__name__`, with underscores replaced with hyphens.
#: If the :class:`Callable <typing.Callable>`'s `__module__`__ attribute
#: contains the string :obj:`"egon.data.datasets."`, the
#: :obj:`~PythonOperator.task_id` is also prefixed with the module name,
#: followed by a dot and with :obj:`"egon.data.datasets."` removed.
#:
#: __ https://docs.python.org/3/reference/datamodel.html#index-34
Task = Union[Callable[[], None], Operator]
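
# Illustration (not part of the original module): a hypothetical callable
# ``scale_buildings`` defined in a module ``egon.data.datasets.heat_demand``
# would be wrapped in a ``PythonOperator`` with
# ``task_id="heat_demand.scale-buildings"``; the same callable defined
# outside of ``egon.data.datasets`` would simply get
# ``task_id="scale-buildings"``.
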
#: A graph of tasks is, in its simplest form, just a single node, i.e. a
#: single :class:`Task`. More complex graphs can be specified by nesting
#: :class:`sets <builtins.set>` and :class:`tuples <builtins.tuple>` of
#: :class:`TaskGraphs <TaskGraph>`. A set of :class:`TaskGraphs
#: <TaskGraph>` means that they are unordered and can be
#: executed in parallel. A :class:`tuple` specifies an implicit ordering,
#: so a :class:`tuple` of :class:`TaskGraphs <TaskGraph>` will be executed
#: sequentially in the given order.
TaskGraph = Union[Task, Set["TaskGraph"], Tuple["TaskGraph", ...]]
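
# Illustration (not part of the original module): with hypothetical tasks
# ``download``, ``clean_a``, ``clean_b`` and ``merge``, the graph
#
#     (download, {clean_a, clean_b}, merge)
#
# runs ``download`` first, then ``clean_a`` and ``clean_b`` in parallel,
# and finally ``merge`` once both are done.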


@dataclass
class Tasks(dict):
    first: Set[Task]
    last: Set[Task]
    graph: TaskGraph = ()

    def __init__(self, graph: TaskGraph):
        """Connect multiple tasks into a potentially complex graph.

        Parses a :class:`TaskGraph` into a :class:`Tasks` object.
        """
        if isinstance(graph, Callable):
            graph = PythonOperator(
                task_id=(
                    f"{graph.__module__.replace('egon.data.datasets.', '')}."
                    if "egon.data.datasets." in graph.__module__
                    else ""
                )
                + graph.__name__.replace("_", "-"),
                python_callable=graph,
            )
        self.graph = graph
        if isinstance(graph, Operator):
            self.first = {graph}
            self.last = {graph}
            self[graph.task_id] = graph
        elif isinstance(graph, abc.Sized) and len(graph) == 0:
            self.first = set()
            self.last = set()
        elif isinstance(graph, abc.Set):
            results = [Tasks(subtasks) for subtasks in graph]
            self.first = {task for result in results for task in result.first}
            self.last = {task for result in results for task in result.last}
            self.update(reduce(lambda d1, d2: dict(d1, **d2), results, {}))
        elif isinstance(graph, tuple):
            results = [Tasks(subtasks) for subtasks in graph]
            for (left, right) in zip(results[:-1], results[1:]):
                for last in left.last:
                    for first in right.first:
                        last.set_downstream(first)
            self.first = results[0].first
            self.last = results[-1].last
            self.update(reduce(lambda d1, d2: dict(d1, **d2), results, {}))
        else:
            raise TypeError(
                "`egon.data.datasets.Tasks` got an argument of type:\n\n"
                f"  {type(graph)}\n\n"
                "where only `Task`s, `Set`s and `Tuple`s are allowed."
            )
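
# Illustration (not part of the original module): parsing the hypothetical
# graph from the previous example,
#
#     tasks = Tasks((download, {clean_a, clean_b}, merge))
#
# yields ``tasks.first`` containing only the operator wrapping ``download``,
# ``tasks.last`` containing only the operator wrapping ``merge``, and
# ``tasks`` itself mapping each ``task_id`` to its operator.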


@dataclass
class Dataset:
    #: The name of the :class:`Dataset`.
    name: str
    #: The :class:`Dataset`'s version. Can be anything from a simple
    #: semantic versioning string like "2.1.3" to a more complex string,
    #: for example "2021-01-01.schleswig-holstein.0" for OpenStreetMap
    #: data.
    #: Note that the latter encodes the :class:`Dataset`'s date, region
    #: and a sequential number in case the data changes without the date
    #: or region changing, for example due to implementation changes.
    version: str
    #: The first task(s) of this :class:`Dataset` will be marked as
    #: downstream of any of the listed dependencies. In the case of a bare
    #: :class:`Task`, a direct link will be created, whereas for a
    #: :class:`Dataset` the link will be made to all of its last tasks.
    dependencies: Iterable[Union[Dataset, Task]] = ()
    #: The tasks of this :class:`Dataset`. A :class:`TaskGraph` will
    #: automatically be converted to :class:`Tasks`.
    tasks: Union[Tasks, TaskGraph] = ()

    def check_version(self, after_execution=()):
        def skip_task(task, *xs, **ks):
            with db.session_scope() as session:
                datasets = session.query(Model).filter_by(name=self.name).all()
                if self.version in [
                    ds.version for ds in datasets
                ] and not re.search(r"\.dev$", self.version):
                    logger.info(
                        f"Dataset '{self.name}' version '{self.version}'"
                        " already executed. Skipping."
                    )
                else:
                    for ds in datasets:
                        session.delete(ds)
                    result = super(type(task), task).execute(*xs, **ks)
                    for function in after_execution:
                        function(session)
                    return result

        return skip_task
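
    # Illustration (not part of the original module): with a stored row
    # ("my-dataset", "2021-01-01.0"), running the same name and version
    # again is skipped; bumping the version to e.g. "2021-01-01.1" deletes
    # the old rows and re-executes the task, and a version string ending in
    # ".dev" is always re-executed.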

    def update(self, session):
        dataset = Model(name=self.name, version=self.version)
        # Link the new row to the dependencies that are `Dataset`s, or
        # operators carrying a `dataset` attribute, matching them by
        # (name, version).
        dependencies = (
            session.query(Model)
            .filter(
                tuple_(Model.name, Model.version).in_(
                    [
                        (dataset.name, dataset.version)
                        for dependency in self.dependencies
                        if isinstance(dependency, Dataset)
                        or hasattr(dependency, "dataset")
                        for dataset in [
                            dependency.dataset
                            if isinstance(dependency, Operator)
                            else dependency
                        ]
                    ]
                )
            )
            .all()
        )
        dataset.dependencies = dependencies
        session.add(dataset)

    def __post_init__(self):
        self.dependencies = list(self.dependencies)
        if not isinstance(self.tasks, Tasks):
            self.tasks = Tasks(self.tasks)
        if len(self.tasks.last) > 1:
            # Explicitly create a single final task, because we can't know
            # which of the multiple tasks finishes last.
            update_version = PythonOperator(
                task_id=f"update-{self.name}-version",
                # Do nothing here; the version update is attached to this
                # task below via `check_version`.
                python_callable=lambda *xs, **ks: None,
            )
            self.tasks = Tasks((self.tasks.graph, update_version))
        # Due to the `if`-block above, there'll now always be exactly
        # one task in `self.tasks.last`, which the next line selects.
        last = list(self.tasks.last)[0]
        for task in self.tasks.values():
            task.dataset = self
            cls = task.__class__
            versioned = type(
                f"{self.name[0].upper()}{self.name[1:]} (versioned)",
                (cls,),
                {
                    "execute": self.check_version(
                        after_execution=[self.update] if task is last else []
                    )
                },
            )
            task.__class__ = versioned

        predecessors = [
            task
            for dataset in self.dependencies
            if isinstance(dataset, Dataset)
            for task in dataset.tasks.last
        ] + [task for task in self.dependencies if isinstance(task, Operator)]
        for p in predecessors:
            for first in self.tasks.first:
                p.set_downstream(first)

    def insert_into(self, dag: DAG):
        for task in self.tasks.values():
            for attribute in dag.default_args:
                if getattr(task, attribute) is None:
                    setattr(task, attribute, dag.default_args[attribute])
        dag.add_tasks(self.tasks.values())
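
For orientation, a minimal usage sketch of the API above (not part of the analysed file): the DAG, the task callables and the dataset name are hypothetical, and the import path for `Dataset` is assumed; only the `Dataset(...)` fields, the tuple/set task-graph syntax and `insert_into()` come from the module itself.

from airflow import DAG
from airflow.utils.dates import days_ago

from egon.data.datasets import Dataset  # assumed import path for the module above


def download():
    """Hypothetical task: fetch the raw input data."""


def clean():
    """Hypothetical task: drop obviously broken records."""


def merge():
    """Hypothetical task: combine the cleaned data into the final tables."""


dag = DAG(
    dag_id="example-pipeline",
    default_args={"start_date": days_ago(1)},
    schedule_interval=None,
)

# The tuple orders the tasks: download, then clean, then merge. Each bare
# callable is wrapped in a PythonOperator, and every task's execute() is
# wrapped by check_version().
example = Dataset(
    name="example",
    version="2021-01-01.0",  # bump this string to force re-execution
    tasks=(download, clean, merge),
)
example.insert_into(dag)

Because `update()` records the name and version in the `metadata.datasets` table once the last task succeeds, re-running the DAG with an unchanged version skips the tasks, while bumping the version string, or using one that ends in ".dev", forces re-execution.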