Passed
Push — dev (836244...247d11)
by Stephan
01:23 queued 11s

data.datasets   A

Complexity

Total Complexity 33

Size/Duplication

Total Lines 265
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 157
dl 0
loc 265
rs 9.76
c 0
b 0
f 0
wmc 33

5 Methods

Rating   Name   Duplication   Size (lines)   Complexity
A Dataset.insert_into() 0 6 4
A Dataset.check_version() 0 18 5
D Tasks.__init__() 0 41 13
A Dataset.update() 0 21 2
B Dataset.__post_init__() 0 40 8

1 Function

Rating   Name   Duplication   Size (lines)   Complexity
A setup() 0 6 1
"""The API for configuring datasets."""

from __future__ import annotations

from collections import abc
from dataclasses import dataclass
from functools import reduce
from typing import Callable, Iterable, Set, Tuple, Union

from airflow import DAG
from airflow.operators import BaseOperator as Operator
from airflow.operators.python_operator import PythonOperator
from sqlalchemy import Column, ForeignKey, Integer, String, Table, orm, tuple_
from sqlalchemy.ext.declarative import declarative_base

from egon.data import db, logger

Base = declarative_base()
SCHEMA = "metadata"


def setup():
    """Create the database structure for storing dataset information."""
    # TODO: Move this into a task generating the initial database structure.
    db.execute_sql(f"CREATE SCHEMA IF NOT EXISTS {SCHEMA};")
    Model.__table__.create(bind=db.engine(), checkfirst=True)
    DependencyGraph.create(bind=db.engine(), checkfirst=True)


# TODO: Figure out how to use a mapped class as an association table.
#
# Trying it out, I ran into quite a few problems and didn't have time to do
# further research. The benefits are mostly just convenience, so it doesn't
# have a high priority. But I'd like to keep the code I started out with to
# have a starting point for me or anybody else trying again.
#
# class DependencyGraph(Base):
#     __tablename__ = "dependency_graph"
#     __table_args__ = {"schema": SCHEMA}
#     dependency_id = Column(Integer, ForeignKey(Model.id), primary_key=True,)
#     dependent_id = Column(Integer, ForeignKey(Model.id), primary_key=True,)

DependencyGraph = Table(
    "dependency_graph",
    Base.metadata,
    Column(
        "dependency_id",
        Integer,
        ForeignKey(f"{SCHEMA}.datasets.id"),
        primary_key=True,
    ),
    Column(
        "dependent_id",
        Integer,
        ForeignKey(f"{SCHEMA}.datasets.id"),
        primary_key=True,
    ),
    schema=SCHEMA,
)


class Model(Base):
    __tablename__ = "datasets"
    __table_args__ = {"schema": SCHEMA}
    id = Column(Integer, primary_key=True)
    name = Column(String, unique=True, nullable=False)
    version = Column(String, nullable=False)
    epoch = Column(Integer, default=0)
    dependencies = orm.relationship(
        "Model",
        secondary=DependencyGraph,
        primaryjoin=id == DependencyGraph.c.dependent_id,
        secondaryjoin=id == DependencyGraph.c.dependency_id,
        backref=orm.backref("dependents", cascade="all, delete"),
    )
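
# Illustrative sketch (hypothetical, not part of the module above) of the two
# tables just defined. `metadata.datasets` keeps one row per dataset name,
# recording the version that was last executed, while
# `metadata.dependency_graph` links a dataset to the datasets it depends on,
# e.g. if a dataset "society-prognosis" depends on a dataset "zensus":
#
#   datasets:           id | name               | version      | epoch
#                        1 | zensus             | 2020-01-01.0 | 0
#                        2 | society-prognosis  | 0.0.1        | 0
#
#   dependency_graph:   dependency_id | dependent_id
#                                   1 | 2
#
# so the "society-prognosis" row has the "zensus" row in its `dependencies`
# relationship, and the "zensus" row lists it under `dependents`.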


#: A :class:`Task` is an Airflow :class:`Operator` or any
#: :class:`Callable <typing.Callable>` taking no arguments and returning
#: :obj:`None`. :class:`Callables <typing.Callable>` will be converted
#: to :class:`Operators <Operator>` by wrapping them in a
#: :class:`PythonOperator` and setting the :obj:`~PythonOperator.task_id`
#: to the :class:`Callable <typing.Callable>`'s
#: :obj:`~definition.__name__`, with underscores replaced with hyphens.
#: If the :class:`Callable <typing.Callable>`'s `__module__`__ attribute
#: contains the string :obj:`"egon.data.datasets."`, the
#: :obj:`~PythonOperator.task_id` is also prefixed with the module name,
#: followed by a dot and with :obj:`"egon.data.datasets."` removed.
#:
#: __ https://docs.python.org/3/reference/datamodel.html#index-34
Task = Union[Callable[[], None], Operator]
#: A graph of tasks is, in its simplest form, just a single node, i.e. a
#: single :class:`Task`. More complex graphs can be specified by nesting
#: :class:`sets <builtins.set>` and :class:`tuples <builtins.tuple>` of
#: :class:`TaskGraphs <TaskGraph>`. A set of :class:`TaskGraphs
#: <TaskGraph>` means that they are unordered and can be
#: executed in parallel. A :class:`tuple` specifies an implicit ordering, so
#: a :class:`tuple` of :class:`TaskGraphs <TaskGraph>` will be executed
#: sequentially in the given order.
TaskGraph = Union[Task, Set["TaskGraph"], Tuple["TaskGraph", ...]]
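
# Illustrative sketch (hypothetical, not part of the module above) of the
# nesting rules just described. A tuple orders its elements while a set
# leaves them unordered, so `example_graph` runs `download` first, then both
# imports in parallel, and `merge` only after both have finished. When such a
# graph is parsed (see `Tasks` below), each callable is wrapped in a
# `PythonOperator`; a callable named `download` living in a module
# `egon.data.datasets.osm` would get the task_id "osm.download", whereas
# `import_buildings` defined anywhere else becomes "import-buildings".
def download(): pass            # placeholder callables, illustration only
def import_buildings(): pass
def import_roads(): pass
def merge(): pass

example_graph: TaskGraph = (download, {import_buildings, import_roads}, merge)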


@dataclass
class Tasks(dict):
    first: Set[Task]
    last: Set[Task]
    graph: TaskGraph = ()

    def __init__(self, graph: TaskGraph):
        """Connect multiple tasks into a potentially complex graph.

        Parses a :class:`TaskGraph` into a :class:`Tasks` object.
        """
        if isinstance(graph, Callable):
            graph = PythonOperator(
                task_id=(
                    f"{graph.__module__.replace('egon.data.datasets.', '')}."
                    if "egon.data.datasets." in graph.__module__
                    else ""
                )
                + graph.__name__.replace("_", "-"),
                python_callable=graph,
            )
        self.graph = graph
        if isinstance(graph, Operator):
            self.first = {graph}
            self.last = {graph}
            self[graph.task_id] = graph
        elif isinstance(graph, abc.Sized) and len(graph) == 0:
            self.first = set()
            self.last = set()
        elif isinstance(graph, abc.Set):
            results = [Tasks(subtasks) for subtasks in graph]
            self.first = {task for result in results for task in result.first}
            self.last = {task for result in results for task in result.last}
            self.update(reduce(lambda d1, d2: dict(d1, **d2), results, {}))
        elif isinstance(graph, tuple):
            results = [Tasks(subtasks) for subtasks in graph]
            for (left, right) in zip(results[:-1], results[1:]):
                for last in left.last:
                    for first in right.first:
                        last.set_downstream(first)
            self.first = results[0].first
            self.last = results[-1].last
            self.update(reduce(lambda d1, d2: dict(d1, **d2), results, {}))
        else:
            raise (
                TypeError(
                    "`egon.data.datasets.Tasks` got an argument of type:\n\n"
                    f"  {type(graph)}\n\n"
                    "where only `Task`s, `Set`s and `Tuple`s are allowed."
                )
            )
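
# Illustrative sketch (hypothetical, not part of the module above) of what
# parsing the example graph from earlier yields; wrapped in a function so
# that no operators are instantiated at import time.
def _tasks_parsing_example() -> Tasks:
    tasks = Tasks(example_graph)
    # The resulting object maps task_ids to operators, e.g. tasks["download"]
    # and tasks["import-buildings"] (underscores become hyphens; no module
    # prefix, because the placeholders are not defined below
    # "egon.data.datasets.<something>").
    # tasks.first == {the "download" operator}: the graph's entry point(s).
    # tasks.last == {the "merge" operator}: where later tasks or dependent
    # datasets attach. The two parallel imports have been wired downstream of
    # "download" and upstream of "merge".
    return tasks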


@dataclass
class Dataset:
    #: The name of the Dataset
    name: str
    #: The :class:`Dataset`'s version. Can be anything from a simple
    #: semantic versioning string like "2.1.3", to a more complex
    #: string, like for example "2021-01-01.schleswig-holstein.0" for
    #: OpenStreetMap data.
    #: Note that the latter encodes the :class:`Dataset`'s date, region
    #: and a sequential number in case the data changes without the date
    #: or region changing, for example due to implementation changes.
    version: str
    #: The first task(s) of this :class:`Dataset` will be marked as
    #: downstream of any of the listed dependencies. In case of a bare
    #: :class:`Task`, a direct link will be created, whereas for a
    #: :class:`Dataset` the link will be made to all of its last tasks.
    dependencies: Iterable[Union[Dataset, Task]] = ()
    #: The tasks of this :class:`Dataset`. A :class:`TaskGraph` will
    #: automatically be converted to :class:`Tasks`.
    tasks: Union[Tasks, TaskGraph] = ()

    def check_version(self, after_execution=()):
        def skip_task(task, *xs, **ks):
            with db.session_scope() as session:
                datasets = session.query(Model).filter_by(name=self.name).all()
                if self.version in [ds.version for ds in datasets]:
                    logger.info(
                        f"Dataset '{self.name}' version '{self.version}'"
                        f" already executed. Skipping."
                    )
                else:
                    for ds in datasets:
                        session.delete(ds)
                    result = super(type(task), task).execute(*xs, **ks)
                    for function in after_execution:
                        function(session)
                    return result

        return skip_task

    def update(self, session):
        dataset = Model(name=self.name, version=self.version)
        dependencies = (
            session.query(Model)
            .filter(
                tuple_(Model.name, Model.version).in_(
                    [
                        (dataset.name, dataset.version)
                        for dependency in self.dependencies
                        for dataset in [
                            dependency.dataset
                            if isinstance(dependency, Operator)
                            else dependency
                        ]
                    ]
                )
            )
            .all()
        )
        dataset.dependencies = dependencies
        session.add(dataset)

    def __post_init__(self):
        self.dependencies = list(self.dependencies)
        if not isinstance(self.tasks, Tasks):
            self.tasks = Tasks(self.tasks)
        if len(self.tasks.last) > 1:
            # Explicitly create a single final task, because we can't know
            # which of the multiple tasks finishes last.
            update_version = PythonOperator(
                task_id=f"update-{self.name}-version",
                # Do nothing, because updating will be added later.
                python_callable=lambda *xs, **ks: None,
            )
            self.tasks = Tasks((self.tasks.graph, update_version))
        # Due to the `if`-block above, there'll now always be exactly
        # one task in `self.tasks.last`, which the next line just
        # selects.
        last = list(self.tasks.last)[0]
        for task in self.tasks.values():
            task.dataset = self
            cls = task.__class__
            versioned = type(
                f"{self.name[0].upper()}{self.name[1:]} (versioned)",
                (cls,),
                {
                    "execute": self.check_version(
                        after_execution=[self.update] if task is last else []
                    )
                },
            )
            task.__class__ = versioned

        predecessors = [
            task
            for dataset in self.dependencies
            if isinstance(dataset, Dataset)
            for task in dataset.tasks.last
        ] + [task for task in self.dependencies if isinstance(task, Operator)]
        for p in predecessors:
            for first in self.tasks.first:
                p.set_downstream(first)

    def insert_into(self, dag: DAG):
        for task in self.tasks.values():
            for attribute in dag.default_args:
                if getattr(task, attribute) is None:
                    setattr(task, attribute, dag.default_args[attribute])
        dag.add_tasks(self.tasks.values())
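
# Illustrative sketch (hypothetical, not part of the module above) pulling the
# pieces together; it reuses the placeholder callables from the earlier
# illustration, and `example_dag` stands in for whatever DAG the pipeline
# builds.
def _dataset_usage_example(example_dag: DAG) -> Dataset:
    example = Dataset(
        name="example",
        # Re-running with an unchanged version is a no-op: check_version()
        # finds the (name, version) row written by update() and skips the
        # tasks. Bumping the version makes them run, and be recorded, again.
        version="2021-01-01.0",
        dependencies=[],  # other Dataset objects and/or bare operators
        tasks=(download, {import_buildings, import_roads}),
    )
    # Because this graph ends in two parallel tasks, __post_init__ appended an
    # explicit "update-example-version" task as the single final task, and it
    # swapped every task's class for a "versioned" subclass whose execute()
    # consults the metadata.datasets table first.
    example.insert_into(example_dag)  # copy default_args, add tasks to the DAG
    return example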