Passed
Pull Request — dev (#1138)
by
unknown
02:19
created

data.datasets.Tasks_.__init__()   D

Complexity

Conditions 12

Size

Total Lines 38
Code Lines 32

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 32
dl 0
loc 38
rs 4.8
c 0
b 0
f 0
cc 12
nop 2

How to fix   Complexity   

Complexity

Complex methods like data.datasets.Tasks_.__init__() often do a lot of different things. To break such a method down, identify cohesive parts of it, such as the individual branches of a long `if`/`elif` chain, and move each into its own helper (Extract Method). For complex classes, a common approach to finding such a component is to look for fields/methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate and is often faster.

1
"""The API for configuring datasets."""
2
3
from __future__ import annotations
4
5
from collections import abc
6
from dataclasses import dataclass
7
from functools import reduce
8
from typing import Callable, Iterable, Set, Tuple, Union
9
import re
10
11
from airflow.operators import BaseOperator as Operator
12
from airflow.operators.python_operator import PythonOperator
13
from sqlalchemy import Column, ForeignKey, Integer, String, Table, orm, tuple_
14
from sqlalchemy.ext.declarative import declarative_base
15
16
from egon.data import db, logger
17
18
#: Name of the database schema holding the dataset bookkeeping tables.
SCHEMA = "metadata"
#: Declarative base class shared by the ORM models of this module.
Base = declarative_base()
20
21
22
def setup():
    """Create the schema and the tables used to keep track of datasets."""
    # TODO: Move this into a task generating the initial database structure.
    db.execute_sql(f"CREATE SCHEMA IF NOT EXISTS {SCHEMA};")
    for table in (Model.__table__, DependencyGraph):
        # `checkfirst=True`: only create the table if it doesn't exist yet.
        table.create(bind=db.engine(), checkfirst=True)
28
29
30
# TODO: Figure out how to use a mapped class as an association table.
31
#
32
# Trying it out, I ran into quite a few problems and didn't have time to do
33
# further research. The benefits are mostly just convenience, so it doesn't
34
# have a high priority. But I'd like to keep the code I started out with to
35
# have a starting point for me or anybody else trying again.
36
#
37
# class DependencyGraph(Base):
38
#     __tablename__ = "dependency_graph"
39
#     __table_args__ = {"schema": SCHEMA}
40
#     dependency_id = Column(Integer, ForeignKey(Model.id), primary_key=True,)
41
#     dependent_id = Column(Integer, ForeignKey(Model.id), primary_key=True,)
42
43
#: Association table of the self-referential many-to-many relationship
#: :attr:`Model.dependencies`: a row links a dependent dataset
#: (``dependent_id``) to one of its dependencies (``dependency_id``).
#: Both columns reference ``datasets.id`` in :data:`SCHEMA` and together
#: form the primary key.
DependencyGraph = Table(
    "dependency_graph",
    Base.metadata,
    *(
        Column(
            column_name,
            Integer,
            ForeignKey(f"{SCHEMA}.datasets.id"),
            primary_key=True,
        )
        for column_name in ("dependency_id", "dependent_id")
    ),
    schema=SCHEMA,
)
60
61
62
class Model(Base):
    """Database record of an executed dataset: its name and version."""

    __tablename__ = "datasets"
    __table_args__ = dict(schema=SCHEMA)
    # Surrogate key, referenced from both columns of `DependencyGraph`.
    id = Column(Integer, primary_key=True)
    # A dataset name can only be stored once.
    name = Column(String, nullable=False, unique=True)
    version = Column(String, nullable=False)
    epoch = Column(Integer, default=0)
    # The datasets this one depends on, stored in `DependencyGraph`. The
    # backref `dependents` gives the reverse direction, and its delete
    # cascade removes dependents together with the dataset they depend on.
    dependencies = orm.relationship(
        "Model",
        backref=orm.backref("dependents", cascade="all, delete"),
        secondary=DependencyGraph,
        primaryjoin=id == DependencyGraph.c.dependent_id,
        secondaryjoin=id == DependencyGraph.c.dependency_id,
    )
76
77
78
#: A :class:`Task` is an Airflow :class:`Operator` or any
#: :class:`Callable <typing.Callable>` taking no arguments and returning
#: :obj:`None`. :class:`Callables <typing.Callable>` will be converted
#: to :class:`Operators <Operator>` by wrapping them in a
#: :class:`PythonOperator` and setting the :obj:`~PythonOperator.task_id`
#: to the :class:`Callable <typing.Callable>`'s
#: :obj:`~definition.__name__`, with underscores replaced with hyphens.
#: If the :class:`Callable <typing.Callable>`'s `__module__`__ attribute
#: contains the string :obj:`"egon.data.datasets."`, the
#: :obj:`~PythonOperator.task_id` is also prefixed with the module name
#: with :obj:`"egon.data.datasets."` removed, followed by a dot.
#:
#: __ https://docs.python.org/3/reference/datamodel.html#index-34
Task = Union[Callable[[], None], Operator]
#: A graph of tasks is, in its simplest form, just a single node, i.e. a
#: single :class:`Task`. More complex graphs can be specified by nesting
#: :class:`sets <builtins.set>` and :class:`tuples <builtins.tuple>` of
#: :class:`TaskGraphs <TaskGraph>`. A set of :class:`TaskGraphs
#: <TaskGraph>` means that they are unordered and can be
#: executed in parallel. A :class:`tuple` specifies an implicit ordering,
#: so a :class:`tuple` of :class:`TaskGraphs <TaskGraph>` will be executed
#: sequentially in the given order. Empty sets and tuples are allowed and
#: simply contain no tasks.
#: Since Python sets can only contain hashable elements, a set nested
#: inside another set (directly or inside a tuple) has to be a
#: :class:`frozenset`.
TaskGraph = Union[Task, Set["TaskGraph"], Tuple["TaskGraph", ...]]
#: A type alias to help specify that something can be an explicit
#: :class:`Tasks_` object or a :class:`TaskGraph`, i.e. something that
#: can be converted to :class:`Tasks_`.
Tasks = Union["Tasks_", TaskGraph]
105
106
107
def prefix(o):
    """Return the task id prefix derived from the module `o` comes from.

    If ``o.__module__`` contains the name of this module followed by a
    dot, every occurrence of that part is removed and a dot is appended,
    e.g. ``"sub."`` for an object from the submodule ``sub``. For all
    other objects, the prefix is the empty string.
    """
    package = __name__ + "."
    origin = o.__module__
    if package not in origin:
        return ""
    return origin.replace(package, "") + "."
111
112
113
@dataclass
class Tasks_(dict):
    """A parsed :class:`TaskGraph`, mapping task ids to their operators.

    Besides being a :class:`dict` from :obj:`~Operator.task_id` to
    :class:`Operator`, an instance knows the tasks the graph starts with
    (:attr:`first`) and ends with (:attr:`last`), as well as the
    normalized :attr:`graph`, in which every :class:`Callable
    <typing.Callable>` has been replaced by its :class:`PythonOperator`.
    """

    #: The tasks without a predecessor inside this graph.
    first: Set[Task]
    #: The tasks without a successor inside this graph.
    last: Set[Task]
    #: The normalized graph. Parallel groups are stored as frozensets, so
    #: that they stay hashable when nested inside other groups.
    graph: TaskGraph = ()

    def __init__(self, graph: TaskGraph):
        """Connect multiple tasks into a potentially complex graph.

        Parses a :class:`TaskGraph` into a :class:`Tasks_` object. As a
        side effect, the last tasks of each element of a :class:`tuple`
        are marked as upstream of the first tasks of the next element
        containing tasks, via :meth:`~Operator.set_downstream`.

        Raises
        ------
        TypeError
            If `graph`, or anything nested in it, is neither a
            :class:`Task`, a :class:`Set <collections.abc.Set>` nor a
            :class:`tuple`.
        """
        if isinstance(graph, Callable):
            graph = PythonOperator(
                task_id=f"{prefix(graph)}{graph.__name__.replace('_', '-')}",
                python_callable=graph,
            )
        self.graph = graph
        if isinstance(graph, Operator):
            self.first = {graph}
            self.last = {graph}
            self[graph.task_id] = graph
        elif isinstance(graph, abc.Sized) and len(graph) == 0:
            # `set()`, not `{}`: the latter would be an empty `dict`.
            self.first = set()
            self.last = set()
        elif isinstance(graph, abc.Set):
            self._parallel([Tasks_(subgraph) for subgraph in graph])
        elif isinstance(graph, tuple):
            self._sequential([Tasks_(subgraph) for subgraph in graph])
        else:
            raise TypeError(
                "`egon.data.datasets.Tasks_` got an argument of type:\n\n"
                f"  {type(graph)}\n\n"
                "where only `Task`s, `Set`s and `Tuple`s are allowed."
            )

    def _merge(self, results):
        """Add the task id to operator mappings of all `results` to `self`."""
        for result in results:
            self.update(result)

    def _parallel(self, results):
        """Combine the parsed subgraphs `results` into a parallel group."""
        self.first = {task for result in results for task in result.first}
        self.last = {task for result in results for task in result.last}
        self._merge(results)
        # A `frozenset` keeps the group hashable, so that it can itself be
        # a member of an enclosing set (or of a tuple inside one).
        self.graph = frozenset(result.graph for result in results)

    def _sequential(self, results):
        """Chain the parsed subgraphs `results` to run one after another."""
        self._merge(results)
        self.graph = tuple(result.graph for result in results)
        # Elements without tasks (e.g. `()` or `set()`) are skipped, so that
        # they neither break the chain between their neighbours nor become
        # the (empty) start or end of the sequence.
        chain = [result for result in results if result.first]
        for left, right in zip(chain[:-1], chain[1:]):
            for last in left.last:
                for first in right.first:
                    last.set_downstream(first)
        self.first = chain[0].first if chain else set()
        self.last = chain[-1].last if chain else set()
161
162
163
#: A dataset can depend on other datasets or on the tasks of other
#: datasets. When the task graph is wired up, only :class:`Dataset`
#: objects (via their last tasks) and :class:`Operator` instances are
#: linked as predecessors; other :class:`Tasks <Task>`, e.g. bare
#: :class:`Callables <typing.Callable>`, are ignored.
Dependencies = Iterable[Union["Dataset", Task]]
165
166
167
@dataclass
class Dataset:
    """A named and versioned group of tasks.

    After construction (see :meth:`__post_init__`), :attr:`tasks` is a
    :class:`Tasks_` object and, as long as it contains any tasks, every
    task is skipped at run time if this :attr:`version` of the dataset
    has already been executed, and the first tasks are scheduled after
    the :attr:`dependencies`.
    """

    #: The name of the Dataset
    name: str
    #: The :class:`Dataset`'s version. Can be anything from a simple
    #: semantic versioning string like "2.1.3", to a more complex
    #: string, like for example "2021-01-01.schleswig-holstein.0" for
    #: OpenStreetMap data.
    #: Note that the latter encodes the :class:`Dataset`'s date, region
    #: and a sequential number in case the data changes without the date
    #: or region changing, for example due to implementation changes.
    #: Versions ending in ".dev" are re-executed on every run.
    version: str
    #: The first task(s) of this :class:`Dataset` will be marked as
    #: downstream of any of the listed dependencies. In case of bare
    #: :class:`Task`, a direct link will be created whereas for a
    #: :class:`Dataset` the link will be made to all of its last tasks.
    #: Only :class:`Operators <Operator>` are linked directly; bare
    #: :class:`Callables <typing.Callable>` are ignored.
    dependencies: Dependencies = ()
    #: The tasks of this :class:`Dataset`. A :class:`TaskGraph` will
    #: automatically be converted to :class:`Tasks_`.
    tasks: Tasks = ()

    def check_version(self, after_execution=()):
        """Create an ``execute`` method skipping already executed versions.

        The returned function is installed as ``execute`` of a dynamically
        created subclass of each task's class (see :meth:`_version_tasks`).
        When called, it loads the database records with this dataset's
        :attr:`name`. If one of them has this :attr:`version` and the
        version does not end in ``".dev"``, the task is skipped. Otherwise
        the loaded records are deleted, the ``execute`` of the class the
        task had before versioning is run, and every function in
        `after_execution` is called with the database session.

        Parameters
        ----------
        after_execution : iterable of callables
            Functions taking the database session, run after the wrapped
            ``execute`` returned without raising.

        Returns
        -------
        callable
            The replacement ``execute`` method.
        """

        def skip_task(task, *xs, **ks):
            with db.session_scope() as session:
                datasets = session.query(Model).filter_by(name=self.name).all()
                executed = self.version in [ds.version for ds in datasets]
                # ".dev" versions are never considered to be up to date.
                if executed and not re.search(r"\.dev$", self.version):
                    logger.info(
                        f"Dataset '{self.name}' version '{self.version}'"
                        " already executed. Skipping."
                    )
                    return None
                for ds in datasets:
                    session.delete(ds)
                # `type(task)` is the versioned subclass, so this runs the
                # `execute` of the class the task had before versioning.
                result = super(type(task), task).execute(*xs, **ks)
                for function in after_execution:
                    function(session)
                return result

        return skip_task

    def _dependency_datasets(self):
        """Return the :class:`Dataset` objects this dataset depends on.

        Datasets are taken as they are; for other dependencies, the value of
        their `dataset` attribute (set on every task in
        :meth:`_version_tasks`) is used. Dependencies having neither are
        ignored.
        """
        return [
            dependency if isinstance(dependency, Dataset) else dependency.dataset
            for dependency in self.dependencies
            if isinstance(dependency, Dataset) or hasattr(dependency, "dataset")
        ]

    def update(self, session):
        """Add a database record for this dataset's name and version.

        The record's dependencies are the already stored records matching
        the name and version of every dataset this dataset depends on.

        Parameters
        ----------
        session
            The database session to query and to add the record to.
        """
        record = Model(name=self.name, version=self.version)
        keys = [
            (dataset.name, dataset.version)
            for dataset in self._dependency_datasets()
        ]
        # Skip the query entirely instead of issuing an empty `IN (...)`.
        record.dependencies = (
            session.query(Model)
            .filter(tuple_(Model.name, Model.version).in_(keys))
            .all()
            if keys
            else []
        )
        session.add(record)

    def __post_init__(self):
        """Parse the tasks, version them and link them to the dependencies.

        A dataset without any tasks is left unversioned and unlinked, as
        there is no task to attach the version check or the links to.
        """
        self.dependencies = list(self.dependencies)
        if not isinstance(self.tasks, Tasks_):
            self.tasks = Tasks_(self.tasks)
        if not self.tasks.last:
            return
        self._ensure_single_last_task()
        self._version_tasks()
        self._link_dependencies()

    def _ensure_single_last_task(self):
        """Append a no-op task if the tasks end in more than one task.

        We can't know which of multiple last tasks finishes last, but the
        database update must run after all of them, so an explicit single
        final task is created.
        """
        if len(self.tasks.last) <= 1:
            return
        name = prefix(self)
        name = f"{name if name else f'{self.__module__}.'}{self.name}."
        update_version = PythonOperator(
            task_id=f"{name}update-version",
            # Do nothing here: as the last task, it gets `self.update`
            # attached as a post-execution hook in `_version_tasks`.
            python_callable=lambda *xs, **ks: None,
        )
        self.tasks = Tasks_((self.tasks.graph, update_version))

    def _version_tasks(self):
        """Make every task skip itself if this version was already run.

        Each task's class is replaced by a dynamically created subclass
        whose ``execute`` comes from :meth:`check_version`. Only the last
        task also records the dataset in the database via :meth:`update`.
        """
        # `_ensure_single_last_task` leaves exactly one last task.
        (last,) = self.tasks.last
        class_name = f"{self.name[:1].upper()}{self.name[1:]} (versioned)"
        for task in self.tasks.values():
            task.dataset = self
            task.__class__ = type(
                class_name,
                (task.__class__,),
                {
                    "execute": self.check_version(
                        after_execution=[self.update] if task is last else []
                    )
                },
            )

    def _link_dependencies(self):
        """Schedule the first tasks after all dependencies.

        :class:`Dataset` dependencies contribute their last tasks,
        :class:`Operator` dependencies themselves; anything else is ignored.
        """
        predecessors = [
            task
            for dataset in self.dependencies
            if isinstance(dataset, Dataset)
            for task in dataset.tasks.last
        ] + [task for task in self.dependencies if isinstance(task, Operator)]
        for predecessor in predecessors:
            for first in self.tasks.first:
                predecessor.set_downstream(first)
275