Passed
Pull Request — dev (#1375)
created by unknown, 02:12

data.datasets.Dataset.__post_init__() (rated F)

Complexity
    Conditions   16

Size
    Total Lines  81
    Code Lines   55

Duplication
    Lines        0
    Ratio        0 %

Importance
    Changes      0

Metric                        Value
eloc                          55
dl (duplicated lines)         0
loc (lines of code)           81
rs                            2.4
c                             0
b                             0
f                             0
cc (cyclomatic complexity)    16
nop (number of parameters)    1

How to fix

Long Method

Small methods make your code easier to understand, in particular when combined with a good name. And when a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.

The most commonly applied refactoring here is Extract Method, as sketched below.
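For example, a sketch of Extract Method (OrderProcessor and its helpers are illustrative names, not from this codebase):

    class OrderProcessor:
        def __init__(self, session):
            self.session = session

        def process(self, order):
            # This used to be one long body with "# validate" and
            # "# persist" comments; the comments became method names.
            self._validate_items(order)
            self._persist(order)

        def _validate_items(self, order):
            if not order.items:
                raise ValueError("order has no items")

        def _persist(self, order):
            self.session.add(order)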

Complexity

Complex classes like the one containing data.datasets.Dataset.__post_init__() often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields and methods that share the same prefixes or suffixes.

Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster. A minimal Extract Class sketch follows.
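This sketch is illustrative only; Report and Header are hypothetical names, not part of this codebase:

    class Header:
        def __init__(self, title, date, author):
            self.title = title
            self.date = date
            self.author = author


    class Report:
        # Before the refactoring, header_title, header_date and
        # header_author sat directly on Report; the shared "header_"
        # prefix marked them as a cohesive component. After Extract
        # Class, they live in their own Header class.
        def __init__(self, title, date, author, body):
            self.header = Header(title, date, author)
            self.body = body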

  1  """The API for configuring datasets."""
  2
  3  from __future__ import annotations
  4
  5  from collections import abc
  6  from dataclasses import dataclass, field
  7  from functools import partial, reduce, update_wrapper
  8  from typing import Callable, Iterable, Set, Tuple, Union
  9  import re
 10
 11  from airflow.models.baseoperator import BaseOperator as Operator
 12  from airflow.operators.python import PythonOperator
 13  from sqlalchemy import Column, ForeignKey, Integer, String, Table, orm, tuple_
 14  from sqlalchemy.ext.declarative import declarative_base
 15  from typing import Dict, List
 16  from egon.data.validation_utils import create_validation_tasks
 17
 18  from egon.data import config, db, logger
 19
 20  try:
 21      from egon_validation.rules.base import Rule
 22  except ImportError:
 23      Rule = None  # Type hint only
 24
 25
 26  Base = declarative_base()
 27  SCHEMA = "metadata"
 28
 29
 30  def wrapped_partial(func, *args, **kwargs):
 31      """Like :func:`functools.partial`, but preserves the original function's
 32      name and docstring. Also allows to add a postfix to the function's name.
 33      """
 34      postfix = kwargs.pop("postfix", None)
 35      partial_func = partial(func, *args, **kwargs)
 36      update_wrapper(partial_func, func)
 37      if postfix:
 38          partial_func.__name__ = f"{func.__name__}{postfix}"
 39      return partial_func
 40
 41
 42  def setup():
 43      """Create the database structure for storing dataset information."""
 44      # TODO: Move this into a task generating the initial database structure.
 45      db.execute_sql(f"CREATE SCHEMA IF NOT EXISTS {SCHEMA};")
 46      Model.__table__.create(bind=db.engine(), checkfirst=True)
 47      DependencyGraph.create(bind=db.engine(), checkfirst=True)
 48
 49
 50  # TODO: Figure out how to use a mapped class as an association table.
 51  #
 52  # Trying it out, I ran into quite a few problems and didn't have time to do
 53  # further research. The benefits are mostly just convenience, so it doesn't
 54  # have a high priority. But I'd like to keep the code I started out with to
 55  # have a starting point for me or anybody else trying again.
 56  #
 57  # class DependencyGraph(Base):
 58  #     __tablename__ = "dependency_graph"
 59  #     __table_args__ = {"schema": SCHEMA}
 60  #     dependency_id = Column(Integer, ForeignKey(Model.id), primary_key=True,)
 61  #     dependent_id = Column(Integer, ForeignKey(Model.id), primary_key=True,)
 62
 63  DependencyGraph = Table(
 64      "dependency_graph",
 65      Base.metadata,
 66      Column(
 67          "dependency_id",
 68          Integer,
 69          ForeignKey(f"{SCHEMA}.datasets.id"),
 70          primary_key=True,
 71      ),
 72      Column(
 73          "dependent_id",
 74          Integer,
 75          ForeignKey(f"{SCHEMA}.datasets.id"),
 76          primary_key=True,
 77      ),
 78      schema=SCHEMA,
 79  )
 80
 81
 82  class Model(Base):
 83      __tablename__ = "datasets"
 84      __table_args__ = {"schema": SCHEMA}
 85      id = Column(Integer, primary_key=True)
 86      name = Column(String, unique=True, nullable=False)
 87      version = Column(String, nullable=False)
 88      epoch = Column(Integer, default=0)
 89      scenarios = Column(String, nullable=False)
 90      dependencies = orm.relationship(
 91          "Model",
 92          secondary=DependencyGraph,
 93          primaryjoin=id == DependencyGraph.c.dependent_id,
 94          secondaryjoin=id == DependencyGraph.c.dependency_id,
 95          backref=orm.backref("dependents", cascade="all, delete"),
 96      )
 97
 98
 99  #: A :class:`Task` is an Airflow :class:`Operator` or any
100  #: :class:`Callable <typing.Callable>` taking no arguments and returning
101  #: :obj:`None`. :class:`Callables <typing.Callable>` will be converted
102  #: to :class:`Operators <Operator>` by wrapping them in a
103  #: :class:`PythonOperator` and setting the :obj:`~PythonOperator.task_id`
104  #: to the :class:`Callable <typing.Callable>`'s
105  #: :obj:`~definition.__name__`, with underscores replaced with hyphens.
106  #: If the :class:`Callable <typing.Callable>`'s `__module__`__ attribute
107  #: contains the string :obj:`"egon.data.datasets."`, the
108  #: :obj:`~PythonOperator.task_id` is also prefixed with the module name,
109  #: followed by a dot and with :obj:`"egon.data.datasets."` removed.
110  #:
111  #: __ https://docs.python.org/3/reference/datamodel.html#index-34
112  Task = Union[Callable[[], None], Operator]
113  #: A graph of tasks is, in its simplest form, just a single node, i.e. a
114  #: single :class:`Task`. More complex graphs can be specified by nesting
115  #: :class:`sets <builtins.set>` and :class:`tuples <builtins.tuple>` of
116  #: :class:`TaskGraphs <TaskGraph>`. A set of :class:`TaskGraphs
117  #: <TaskGraph>` means that they are unordered and can be
118  #: executed in parallel. A :class:`tuple` specifies an implicit ordering so
119  #: a :class:`tuple` of :class:`TaskGraphs <TaskGraph>` will be executed
120  #: sequentially in the given order.
121  TaskGraph = Union[Task, Set["TaskGraph"], Tuple["TaskGraph", ...]]
122  #: A type alias to help specifying that something can be an explicit
123  #: :class:`Tasks_` object or a :class:`TaskGraph`, i.e. something that
124  #: can be converted to :class:`Tasks_`.
125  Tasks = Union["Tasks_", TaskGraph]
126
127
128  def prefix(o):
129      module = o.__module__
130      parent = f"{__name__}."
131      return f"{module.replace(parent, '')}." if parent in module else ""
132
133
134  @dataclass
135  class Tasks_(dict):
136      first: Set[Task]
137      last: Set[Task]
138      graph: TaskGraph = ()
139
140      def __init__(self, graph: TaskGraph):
141          """Connect multiple tasks into a potentially complex graph.
142
143          Parses a :class:`TaskGraph` into a :class:`Tasks_` object.
144          """
145          if isinstance(graph, Callable):
146              graph = PythonOperator(
147                  task_id=f"{prefix(graph)}{graph.__name__.replace('_', '-')}",
148                  python_callable=graph,
149              )
150          self.graph = graph
151          if isinstance(graph, Operator):
152              self.first = {graph}
153              self.last = {graph}
154              self[graph.task_id] = graph
155          elif isinstance(graph, abc.Sized) and len(graph) == 0:
156              self.first = {}
157              self.last = {}
158          elif isinstance(graph, abc.Set):
159              results = [Tasks_(subtasks) for subtasks in graph]
160              self.first = {task for result in results for task in result.first}
161              self.last = {task for result in results for task in result.last}
162              self.update(reduce(lambda d1, d2: dict(d1, **d2), results, {}))
163              self.graph = set(tasks.graph for tasks in results)
164          elif isinstance(graph, tuple):
165              results = [Tasks_(subtasks) for subtasks in graph]
166              for left, right in zip(results[:-1], results[1:]):
167                  for last in left.last:
168                      for first in right.first:
169                          last.set_downstream(first)
170              self.first = results[0].first
171              self.last = results[-1].last
172              self.update(reduce(lambda d1, d2: dict(d1, **d2), results, {}))
173              self.graph = tuple(tasks.graph for tasks in results)
174          else:
175              raise (
176                  TypeError(
177                      "`egon.data.datasets.Tasks_` got an argument of type:\n\n"
178                      f"  {type(graph)}\n\n"
179                      "where only `Task`s, `Set`s and `Tuple`s are allowed."
180                  )
181              )
182
183
184  #: A dataset can depend on other datasets or the tasks of other datasets.
185  Dependencies = Iterable[Union["Dataset", Task]]
186
187
188  @dataclass
189  class Dataset:
190      #: The name of the Dataset
191      name: str
192      #: The :class:`Dataset`'s version. Can be anything from a simple
193      #: semantic versioning string like "2.1.3", to a more complex
194      #: string, like for example "2021-01-01.schleswig-holstein.0" for
195      #: OpenStreetMap data.
196      #: Note that the latter encodes the :class:`Dataset`'s date, region
197      #: and a sequential number in case the data changes without the date
198      #: or region changing, for example due to implementation changes.
199      version: str
200      #: The first task(s) of this :class:`Dataset` will be marked as
201      #: downstream of any of the listed dependencies. In case of bare
202      #: :class:`Task`, a direct link will be created whereas for a
203      #: :class:`Dataset` the link will be made to all of its last tasks.
204      dependencies: Dependencies = ()
205      #: The tasks of this :class:`Dataset`. A :class:`TaskGraph` will
206      #: automatically be converted to :class:`Tasks_`.
207      tasks: Tasks = ()
208      validation: Dict[str, List] = field(default_factory=dict)
209      validation_on_failure: str = "continue"
210
211      def check_version(self, after_execution=()):
212          scenario_names = config.settings()["egon-data"]["--scenarios"]
213
214          def skip_task(task, *xs, **ks):
215              with db.session_scope() as session:
216                  datasets = session.query(Model).filter_by(name=self.name).all()
217                  if (
218                      self.version in [ds.version for ds in datasets]
219                      and all(
220                          scenario_names
221                          == ds.scenarios.replace("{", "")
222                          .replace("}", "")
223                          .split(",")
224                          for ds in datasets
225                      )
226                      and not re.search(r"\.dev$", self.version)
227                  ):
228                      logger.info(
229                          f"Dataset '{self.name}' version '{self.version}'"
230                          f" scenarios {scenario_names}"
231                          f" already executed. Skipping."
232                      )
233                  else:
234                      for ds in datasets:
235                          session.delete(ds)
236                      result = super(type(task), task).execute(*xs, **ks)
237                      for function in after_execution:
238                          function(session)
239                      return result
240
241          return skip_task
242
243      def update(self, session):
244          dataset = Model(
245              name=self.name,
246              version=self.version,
247              scenarios=config.settings()["egon-data"]["--scenarios"],
248          )
249          dependencies = (
250              session.query(Model)
251              .filter(
252                  tuple_(Model.name, Model.version).in_(
253                      [
254                          (dataset.name, dataset.version)
255                          for dependency in self.dependencies
256                          if isinstance(dependency, Dataset)
257                          or hasattr(dependency, "dataset")
258                          for dataset in [
259                              (
260                                  dependency.dataset
261                                  if isinstance(dependency, Operator)
262                                  else dependency
263                              )
264                          ]
265                      ]
266                  )
267              )
268              .all()
269          )
270          dataset.dependencies = dependencies
271          session.add(dataset)
272
273      def __post_init__(self):
274          self.dependencies = list(self.dependencies)
275          if not isinstance(self.tasks, Tasks_):
276              self.tasks = Tasks_(self.tasks)
277              # Process validation configuration
278          if self.validation:
279              validation_tasks = create_validation_tasks(
280                  validation_dict=self.validation,
281                  dataset_name=self.name,
282                  on_failure=self.validation_on_failure
283              )
284
285              # Append validation tasks to existing tasks
286              if validation_tasks:
287                  task_list = list(self.tasks.graph if hasattr(self.tasks, 'graph') else self.tasks)
288                  task_list.extend(validation_tasks)
289                  self.tasks = Tasks_(tuple(task_list))
290
291          if len(self.tasks.last) > 1:
292              # Explicitly create single final task, because we can't know
293              # which of the multiple tasks finishes last.
294              name = prefix(self)
295              name = f"{name if name else f'{self.__module__}.'}{self.name}."
296              update_version = PythonOperator(
297                  task_id=f"{name}update-version",
298                  # Do nothing, because updating will be added later.
299                  python_callable=lambda *xs, **ks: None,
300              )
301              self.tasks = Tasks_((self.tasks.graph, update_version))
302          # Due to the `if`-block above, there'll now always be exactly
303          # one task in `self.tasks.last` which the next line just
304          # selects.
305          last = list(self.tasks.last)[0]
306          for task in self.tasks.values():
307              task.dataset = self
308              cls = task.__class__
309              versioned = type(
310                  f"{self.name[0].upper()}{self.name[1:]} (versioned)",
311                  (cls,),
312                  {
313                      "execute": self.check_version(
314                          after_execution=[self.update] if task is last else []
315                      )
316                  },
317              )
318              task.__class__ = versioned
319
320          predecessors = [
321              task
322              for dataset in self.dependencies
323              if isinstance(dataset, Dataset)
324              for task in dataset.tasks.last
325          ] + [task for task in self.dependencies if isinstance(task, Operator)]
326          for p in predecessors:
327              for first in self.tasks.first:
328                  p.set_downstream(first)
329
330          # Link validation tasks to run after data tasks
331          if self.validation and validation_tasks:
     [Inline issue, introduced by this pull request]
     The variable validation_tasks does not seem to be defined in case
     self.validation on line 278 is False. Are you sure this can never
     be the case?
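     Note: this warning is likely a false positive. Python's `and`
     operator short-circuits, so `validation_tasks` on line 331 is only
     evaluated when `self.validation` is truthy, and that is exactly the
     branch (line 278) that assigns it. A minimal illustration of the
     short-circuit behaviour:

         flag = False
         # `result` was never assigned, yet no NameError is raised,
         # because `and` stops evaluating as soon as `flag` is False.
         if flag and result:
             print("unreachable")

     Initialising `validation_tasks = []` before the `if self.validation:`
     guard would nevertheless silence the warning and make the invariant
     explicit.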
332              # Get last non-validation tasks
333              non_validation_task_ids = [
334                  task.task_id for task in self.tasks.values()
335                  if not any(task.task_id.endswith(f".validate.{name}") for name in self.validation.keys())
336              ]
337
338              last_data_tasks = [
339                  task for task in self.tasks.values()
340                  if task.task_id in non_validation_task_ids and task in self.tasks.last
341              ]
342
343              if not last_data_tasks:
344                  # Fallback to last non-validation task
345                  last_data_tasks = [
346                                        task for task in self.tasks.values()
347                                        if task.task_id in non_validation_task_ids
348                                    ][-1:]
349
350              # Link each validation task downstream of last data tasks
351              for validation_task in validation_tasks:
352                  for last_task in last_data_tasks:
353                      last_task.set_downstream(validation_task)
354
355
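Applied to __post_init__() itself, the Extract Method advice above would move each commented block into its own helper. A hedged sketch of the first step (the helper name _append_validation_tasks is hypothetical, not part of the PR; behaviour is kept as-is):

    def __post_init__(self):
        self.dependencies = list(self.dependencies)
        if not isinstance(self.tasks, Tasks_):
            self.tasks = Tasks_(self.tasks)
        validation_tasks = self._append_validation_tasks()
        ...  # remaining blocks, each extracted the same way

    def _append_validation_tasks(self):
        """Create validation tasks and append them to the task graph."""
        if not self.validation:
            return []
        validation_tasks = create_validation_tasks(
            validation_dict=self.validation,
            dataset_name=self.name,
            on_failure=self.validation_on_failure,
        )
        if validation_tasks:
            task_list = list(
                self.tasks.graph if hasattr(self.tasks, "graph") else self.tasks
            )
            task_list.extend(validation_tasks)
            self.tasks = Tasks_(tuple(task_list))
        return validation_tasks

Each helper extracted this way removes a handful of conditions from __post_init__(), which is what drives the method's cc value of 16, without changing behaviour.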