Passed
Pull Request — dev (#234)
by Stephan
01:10
created

data.datasets   A

Complexity

Total Complexity 14

Size/Duplication

Total Lines 129
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 90
dl 0
loc 129
rs 10
c 0
b 0
f 0
wmc 14

2 Functions

Rating   Name   Duplication   Size   Complexity  
C connect() 0 32 9
A setup() 0 6 1

2 Methods

Rating   Name   Duplication   Size   Complexity  
A Dataset.insert_into() 0 2 1
A Dataset.__post_init__() 0 6 3
1
"""The API for configuring datasets."""
2
3
from dataclasses import dataclass
4
from typing import List, Set, Tuple, Union
5
import collections.abc as cabc
6
7
from airflow import DAG
8
from airflow.operators import BaseOperator as Operator
9
from airflow.utils.dates import days_ago
10
from sqlalchemy import Column, ForeignKey, Integer, String, Table, orm
11
from sqlalchemy.ext.declarative import declarative_base
12
13
from egon.data import db
14
15
# Declarative base for the ORM classes of this module. Its `metadata` is
# also the registry the `dependency_graph` table is attached to.
Base = declarative_base()

# Name of the database schema holding the dataset bookkeeping tables.
SCHEMA = "metadata"

# NOTE(review): not referenced in the visible part of this module; presumably
# meant to be passed as default arguments when constructing a DAG -- confirm.
DEFAULTS = {"start_date": days_ago(1)}
18
19
20
def setup():
    """Create the database structure for storing dataset information.

    First makes sure the schema named by `SCHEMA` exists, then creates the
    table behind `Model` and the `DependencyGraph` association table, in
    that order. Both tables are created with `checkfirst=True`, so tables
    which already exist are left alone.
    """
    # TODO: Move this into a task generating the initial database structure.
    db.execute_sql(f"CREATE SCHEMA IF NOT EXISTS {SCHEMA};")
    # `Model`'s table has to come first, because the dependency graph's
    # foreign keys point at it.
    for table in (Model.__table__, DependencyGraph):
        table.create(bind=db.engine(), checkfirst=True)
26
27
28
# Association table recording which dataset depends on which other dataset.
# Each row links a `dependent_id` to the `dependency_id` it depends on; both
# columns reference `datasets.id` and together form the primary key, so a
# given dependency edge can only be stored once.
DependencyGraph = Table(
    "dependency_graph",
    Base.metadata,
    Column(
        "dependency_id",
        Integer,
        ForeignKey(f"{SCHEMA}.datasets.id"),
        primary_key=True,
    ),
    Column(
        "dependent_id",
        Integer,
        ForeignKey(f"{SCHEMA}.datasets.id"),
        primary_key=True,
    ),
    # Use the shared constant instead of repeating the literal, so that this
    # table always lives in the same schema its foreign keys refer to.
    schema=SCHEMA,
)
45
46
47
class Model(Base):
    """ORM mapping of the `datasets` table inside the `SCHEMA` schema.

    A row stores a dataset's unique `name`, its `version` and an integer
    `epoch` which defaults to 0. The `dependencies` relationship, and the
    `dependents` back reference it generates, expose the edges stored in
    the `dependency_graph` association table.
    """

    __tablename__ = "datasets"
    __table_args__ = {"schema": SCHEMA}

    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False, unique=True)
    version = Column(String, nullable=False)
    epoch = Column(Integer, default=0)

    # Self-referential many-to-many relationship: starting from a row which
    # appears as `dependent_id`, follow the association table to the rows
    # referenced by `dependency_id`.
    dependencies = orm.relationship(
        "Model",
        backref="dependents",
        secondary=DependencyGraph,
        primaryjoin=id == DependencyGraph.c.dependent_id,
        secondaryjoin=id == DependencyGraph.c.dependency_id,
    )
61
62
63
# A task graph is defined recursively: it is either a single operator, a set
# of task graphs or a tuple of task graphs. The two container aliases are
# referenced by name (as strings) in `TaskGraph`, because they are only
# defined after it.
TaskGraph = Union[Operator, "ParallelTasks", "SequentialTasks"]

# Task graphs without any ordering among them.
ParallelTasks = Set[TaskGraph]

# Task graphs which follow each other in the given order.
SequentialTasks = Tuple[TaskGraph, ...]
66
67
68
@dataclass
class Tasks:
    """The operators of a connected task graph, as returned by `connect`."""

    # Operators at the start of the graph, i.e. the ones which get wired to
    # the last operators of whatever precedes the graph.
    first: Set[Operator]
    # Operators at the end of the graph, i.e. the ones which get wired to
    # the first operators of whatever follows the graph.
    last: Set[Operator]
    # Every operator contained in the graph.
    all: Set[Operator]
73
74
75
def connect(tasks: TaskGraph):
    """Connect multiple tasks into a potentially complex graph.

    As per the type, a task graph can be given as a single operator, a tuple
    of task graphs or a set of task graphs. A tuple will be executed in the
    specified order, whereas a set means that the tasks in the graph will be
    executed in parallel.

    Parameters
    ----------
    tasks
        The task graph to connect. Nested graphs are connected recursively.

    Returns
    -------
    Tasks
        The operators at the start of the graph (`first`), the operators at
        its end (`last`) and all operators it contains (`all`). Any empty
        sized container yields a `Tasks` made up of three empty sets.

    Raises
    ------
    TypeError
        If `tasks` is neither an operator, nor empty, nor a set or a tuple.
    """
    if isinstance(tasks, Operator):
        return Tasks(first={tasks}, last={tasks}, all={tasks})

    if isinstance(tasks, cabc.Sized) and len(tasks) == 0:
        # Note that `{}` would be an empty `dict`, not an empty set.
        return Tasks(first=set(), last=set(), all=set())

    if isinstance(tasks, cabc.Set):
        results = [connect(subgraph) for subgraph in tasks]
        return Tasks(
            first={task for result in results for task in result.first},
            last={task for result in results for task in result.last},
            all={task for result in results for task in result.all},
        )

    if isinstance(tasks, tuple):
        results = [connect(subgraph) for subgraph in tasks]
        # Empty stages contain nothing to wire up. Keeping them would break
        # the chain, e.g. `(a, (), b)` would leave `a` and `b` unconnected
        # and `((), a)` would report no first tasks at all.
        stages = [result for result in results if result.all]
        if not stages:
            return Tasks(first=set(), last=set(), all=set())
        # Every task ending a stage precedes every task starting the next.
        for left, right in zip(stages, stages[1:]):
            for upstream in left.last:
                for downstream in right.first:
                    upstream.set_downstream(downstream)
        return Tasks(
            first=stages[0].first,
            last=stages[-1].last,
            all={task for stage in stages for task in stage.all},
        )

    raise TypeError(
        "`egon.data.datasets.connect` got an argument of type:\n\n"
        f"  {type(tasks)}\n\n"
        "where only `Operator`s, `Set`s and `Tuple`s are allowed."
    )
111
112
113
@dataclass
class Dataset:
    """A named, versioned task graph which may depend on other datasets.

    Constructing a dataset connects its `graph` and makes the first tasks
    of that graph downstream of the last tasks of every dataset listed in
    `dependencies`.
    """

    name: str
    version: str
    dependencies: List["Dataset"]
    graph: TaskGraph

    def __post_init__(self):
        """Connect `graph` and hook it up behind all `dependencies`."""
        self.tasks = connect(self.graph)
        # Collect the final tasks of all dependencies before wiring
        # anything up.
        upstream = []
        for dependency in self.dependencies:
            upstream.extend(dependency.tasks.last)
        for predecessor in upstream:
            for entry in self.tasks.first:
                predecessor.set_downstream(entry)

    def insert_into(self, dag: DAG):
        """Add all tasks of this dataset to `dag`."""
        dag.add_tasks(self.tasks.all)
129