Passed — Pull Request — dev (#1304) by unknown, created 01:59
Analyzed function: data.datasets.Dataset.__init_subclass__() (grade A)

Complexity
    Conditions      3

Size
    Total lines     11
    Code lines      9

Duplication
    Lines           0
    Ratio           0 %

Importance
    Changes         0

Metric    Value
eloc      9
dl        0
loc       11
rs        9.95
c         0
b         0
f         0
cc        3
nop       1
"""The API for configuring datasets."""

from __future__ import annotations

from collections import abc
from dataclasses import dataclass, field
from functools import partial, reduce, update_wrapper
from pathlib import Path
from typing import Callable, Dict, Iterable, Set, Tuple, Union
import json
import re

from airflow.models.baseoperator import BaseOperator as Operator
from airflow.operators.python import PythonOperator
from sqlalchemy import Column, ForeignKey, Integer, String, Table, orm, tuple_
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.declarative import declarative_base

from egon.data import config, db, logger

Base = declarative_base()
SCHEMA = "metadata"


def wrapped_partial(func, *args, **kwargs):
    """Like :func:`functools.partial`, but preserves the original function's
    name and docstring. Also allows adding a postfix to the function's name.
    """
    postfix = kwargs.pop("postfix", None)
    partial_func = partial(func, *args, **kwargs)
    update_wrapper(partial_func, func)
    if postfix:
        partial_func.__name__ = f"{func.__name__}{postfix}"
    return partial_func
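
# Illustrative usage of `wrapped_partial` (a sketch; the function and year
# below are made up): the postfix yields a variant with a distinct
# `__name__`, which matters because task ids are derived from function
# names.
#
#     def import_buildings(year):
#         ...
#
#     import_buildings_2035 = wrapped_partial(
#         import_buildings, year=2035, postfix="_2035"
#     )
#     assert import_buildings_2035.__name__ == "import_buildings_2035"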


def setup():
    """Create the database structure for storing dataset information."""
    # TODO: Move this into a task generating the initial database structure.
    db.execute_sql(f"CREATE SCHEMA IF NOT EXISTS {SCHEMA};")
    Model.__table__.create(bind=db.engine(), checkfirst=True)
    DependencyGraph.create(bind=db.engine(), checkfirst=True)


# TODO: Figure out how to use a mapped class as an association table.
#
# Trying it out, I ran into quite a few problems and didn't have time to do
# further research. The benefits are mostly just convenience, so it doesn't
# have a high priority. But I'd like to keep the code I started out with to
# have a starting point for me or anybody else trying again.
#
# class DependencyGraph(Base):
#     __tablename__ = "dependency_graph"
#     __table_args__ = {"schema": SCHEMA}
#     dependency_id = Column(Integer, ForeignKey(Model.id), primary_key=True,)
#     dependent_id = Column(Integer, ForeignKey(Model.id), primary_key=True,)

DependencyGraph = Table(
    "dependency_graph",
    Base.metadata,
    Column(
        "dependency_id",
        Integer,
        ForeignKey(f"{SCHEMA}.datasets.id"),
        primary_key=True,
    ),
    Column(
        "dependent_id",
        Integer,
        ForeignKey(f"{SCHEMA}.datasets.id"),
        primary_key=True,
    ),
    schema=SCHEMA,
)


class Model(Base):
    __tablename__ = "datasets"
    __table_args__ = {"schema": SCHEMA}
    id = Column(Integer, primary_key=True)
    name = Column(String, unique=True, nullable=False)
    version = Column(String, nullable=False)
    epoch = Column(Integer, default=0)
    scenarios = Column(String, nullable=False)
    sources = Column(JSONB, nullable=True)
    targets = Column(JSONB, nullable=True)

    dependencies = orm.relationship(
        "Model",
        secondary=DependencyGraph,
        primaryjoin=id == DependencyGraph.c.dependent_id,
        secondaryjoin=id == DependencyGraph.c.dependency_id,
        backref=orm.backref("dependents", cascade="all, delete"),
    )
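
# Illustrative sketch (hypothetical dataset name): the self-referential
# relationship above lets the dependency graph be walked in both
# directions, e.g.
#
#     with db.session_scope() as session:
#         ds = session.query(Model).filter_by(name="HeatDemand").first()
#         upstream = [d.name for d in ds.dependencies]  # what `ds` needs
#         downstream = [d.name for d in ds.dependents]  # what needs `ds`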


@dataclass
class DatasetSources:
    tables: Dict[str, str] = field(default_factory=dict)
    files: Dict[str, str] = field(default_factory=dict)
    urls: Dict[str, str] = field(default_factory=dict)

    def empty(self):
        return not (self.tables or self.files or self.urls)

    def get_table_schema(self, key: str) -> str:
        """Returns the schema of the table identified by key."""
        try:
            return self.tables[key].split(".", 1)[0]
        except (KeyError, AttributeError, IndexError):
            raise ValueError(f"Invalid table reference: {self.tables.get(key)}")

    def get_table_name(self, key: str) -> str:
        """Returns the table name of the table identified by key."""
        try:
            return self.tables[key].split(".", 1)[1]
        except (KeyError, AttributeError, IndexError):
            raise ValueError(f"Invalid table reference: {self.tables.get(key)}")

    def to_dict(self):
        return {
            "tables": self.tables,
            "urls": self.urls,
            "files": self.files,
        }

    @classmethod
    def from_dict(cls, data):
        return cls(
            tables=data.get("tables", {}),
            urls=data.get("urls", {}),
            files=data.get("files", {}),
        )
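
# Illustrative usage (a sketch; the key and table name are made up): table
# references are "schema.table" strings, so
#
#     sources = DatasetSources(tables={"osm": "openstreetmap.osm_ways"})
#     assert sources.get_table_schema("osm") == "openstreetmap"
#     assert sources.get_table_name("osm") == "osm_ways"
#     assert DatasetSources.from_dict(sources.to_dict()) == sources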


@dataclass
class DatasetTargets:
    tables: Dict[str, str] = field(default_factory=dict)
    files: Dict[str, str] = field(default_factory=dict)

    def empty(self):
        return not (self.tables or self.files)

    def get_table_schema(self, key: str) -> str:
        """Returns the schema of the table identified by key."""
        try:
            return self.tables[key].split(".", 1)[0]
        except (KeyError, AttributeError, IndexError):
            raise ValueError(f"Invalid table reference: {self.tables.get(key)}")

    def get_table_name(self, key: str) -> str:
        """Returns the table name of the table identified by key."""
        try:
            return self.tables[key].split(".", 1)[1]
        except (KeyError, AttributeError, IndexError):
            raise ValueError(f"Invalid table reference: {self.tables.get(key)}")

    def to_dict(self):
        return {
            "tables": self.tables,
            "files": self.files,
        }

    @classmethod
    def from_dict(cls, data):
        return cls(
            tables=data.get("tables", {}),
            files=data.get("files", {}),
        )


#: A :class:`Task` is an Airflow :class:`Operator` or any
#: :class:`Callable <typing.Callable>` taking no arguments and returning
#: :obj:`None`. :class:`Callables <typing.Callable>` will be converted
#: to :class:`Operators <Operator>` by wrapping them in a
#: :class:`PythonOperator` and setting the :obj:`~PythonOperator.task_id`
#: to the :class:`Callable <typing.Callable>`'s
#: :obj:`~definition.__name__`, with underscores replaced with hyphens.
#: If the :class:`Callable <typing.Callable>`'s `__module__`__ attribute
#: contains the string :obj:`"egon.data.datasets."`, the
#: :obj:`~PythonOperator.task_id` is also prefixed with the module name,
#: followed by a dot and with :obj:`"egon.data.datasets."` removed.
#:
#: __ https://docs.python.org/3/reference/datamodel.html#index-34
Task = Union[Callable[[], None], Operator]
#: A graph of tasks is, in its simplest form, just a single node, i.e. a
#: single :class:`Task`. More complex graphs can be specified by nesting
#: :class:`sets <builtins.set>` and :class:`tuples <builtins.tuple>` of
#: :class:`TaskGraphs <TaskGraph>`. A set of :class:`TaskGraphs
#: <TaskGraph>` means that they are unordered and can be executed in
#: parallel. A :class:`tuple` specifies an implicit ordering, so a
#: :class:`tuple` of :class:`TaskGraphs <TaskGraph>` will be executed
#: sequentially in the given order.
TaskGraph = Union[Task, Set["TaskGraph"], Tuple["TaskGraph", ...]]
#: A type alias to help specify that something can be an explicit
#: :class:`Tasks_` object or a :class:`TaskGraph`, i.e. something that
#: can be converted to :class:`Tasks_`.
Tasks = Union["Tasks_", TaskGraph]
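
# Illustrative sketch (hypothetical task callables): tuples order their
# parts, sets leave them unordered, so
#
#     graph = (download, {clean_roads, clean_buildings}, merge)
#
# runs `download` first, then `clean_roads` and `clean_buildings` in
# parallel, and `merge` once both have finished.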


def prefix(o):
    module = o.__module__
    parent = f"{__name__}."
    return f"{module.replace(parent, '')}." if parent in module else ""
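
# For example (illustrative module names): a callable defined in
# `egon.data.datasets.osm` yields the prefix "osm.", while a callable
# defined outside this package yields "".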


@dataclass
class Tasks_(dict):
    first: Set[Task]
    last: Set[Task]
    graph: TaskGraph = ()

    def __init__(self, graph: TaskGraph):
        """Connect multiple tasks into a potentially complex graph.

        Parses a :class:`TaskGraph` into a :class:`Tasks_` object.
        """
        if isinstance(graph, Callable):
            graph = PythonOperator(
                task_id=f"{prefix(graph)}{graph.__name__.replace('_', '-')}",
                python_callable=graph,
            )
        self.graph = graph
        if isinstance(graph, Operator):
            self.first = {graph}
            self.last = {graph}
            self[graph.task_id] = graph
        elif isinstance(graph, abc.Sized) and len(graph) == 0:
            # Use empty sets, not `{}` (which would create `dict`s), to
            # match the `Set[Task]` annotations above.
            self.first = set()
            self.last = set()
        elif isinstance(graph, abc.Set):
            results = [Tasks_(subtasks) for subtasks in graph]
            self.first = {task for result in results for task in result.first}
            self.last = {task for result in results for task in result.last}
            self.update(reduce(lambda d1, d2: dict(d1, **d2), results, {}))
            self.graph = set(tasks.graph for tasks in results)
        elif isinstance(graph, tuple):
            results = [Tasks_(subtasks) for subtasks in graph]
            for left, right in zip(results[:-1], results[1:]):
                for last in left.last:
                    for first in right.first:
                        last.set_downstream(first)
            self.first = results[0].first
            self.last = results[-1].last
            self.update(reduce(lambda d1, d2: dict(d1, **d2), results, {}))
            self.graph = tuple(tasks.graph for tasks in results)
        else:
            raise TypeError(
                "`egon.data.datasets.Tasks_` got an argument of type:\n\n"
                f"  {type(graph)}\n\n"
                "where only `Task`s, `Set`s and `Tuple`s are allowed."
            )
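
# Illustrative sketch (hypothetical callables): parsing a graph wires up
# the Airflow dependencies and records the entry and exit tasks, e.g.
#
#     tasks = Tasks_((download, {clean_a, clean_b}, merge))
#     tasks.first  # the set containing the `download` operator
#     tasks.last   # the set containing the `merge` operator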

#: A dataset can depend on other datasets or the tasks of other datasets.
Dependencies = Iterable[Union["Dataset", Task]]


@dataclass
class Dataset:
    #: The name of the :class:`Dataset`.
    name: str
    #: The :class:`Dataset`'s version. Can be anything from a simple
    #: semantic versioning string like "2.1.3", to a more complex
    #: string, like for example "2021-01-01.schleswig-holstein.0" for
    #: OpenStreetMap data.
    #: Note that the latter encodes the :class:`Dataset`'s date, region
    #: and a sequential number in case the data changes without the date
    #: or region changing, for example due to implementation changes.
    version: str
    #: The sources used by the dataset. These can be tables, files and
    #: URLs.
    sources: DatasetSources = field(init=False)
    #: The targets created by the dataset. These can be tables and
    #: files.
    targets: DatasetTargets = field(init=False)
    #: The first task(s) of this :class:`Dataset` will be marked as
    #: downstream of any of the listed dependencies. In case of a bare
    #: :class:`Task`, a direct link will be created, whereas for a
    #: :class:`Dataset` the link will be made to all of its last tasks.
    dependencies: Dependencies = ()
    #: The tasks of this :class:`Dataset`. A :class:`TaskGraph` will
    #: automatically be converted to :class:`Tasks_`.
    tasks: Tasks = ()

    def check_version(self, after_execution=()):
        # Build a replacement `execute` which skips the task if this
        # version has already run for the configured scenarios.
        scenario_names = config.settings()["egon-data"]["--scenarios"]

        def skip_task(task, *xs, **ks):
            with db.session_scope() as session:
                datasets = session.query(Model).filter_by(name=self.name).all()
                if (
                    self.version in [ds.version for ds in datasets]
                    and all(
                        scenario_names
                        == ds.scenarios.replace("{", "")
                        .replace("}", "")
                        .split(",")
                        for ds in datasets
                    )
                    and not re.search(r"\.dev$", self.version)
                ):
                    logger.info(
                        f"Dataset '{self.name}' version '{self.version}'"
                        f" scenarios {scenario_names}"
                        " already executed. Skipping."
                    )
                else:
                    for ds in datasets:
                        session.delete(ds)
                    result = super(type(task), task).execute(*xs, **ks)
                    for function in after_execution:
                        function(session)
                    return result

        return skip_task
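
    # Note on the skip logic above (illustrative, hypothetical version
    # strings): development versions are never skipped, e.g.
    #
    #     re.search(r"\.dev$", "2021-01-01.0.dev")  # matches, so re-run
    #     re.search(r"\.dev$", "2021-01-01.0")      # None, may be skipped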

    def update(self, session):
        dataset = Model(
            name=self.name,
            version=self.version,
            scenarios=config.settings()["egon-data"]["--scenarios"],
            sources=self.sources.to_dict()
            if hasattr(self.sources, "to_dict")
            else dict(self.sources),
            targets=self.targets.to_dict()
            if hasattr(self.targets, "to_dict")
            else dict(self.targets),
        )

        # Dependencies can be `Dataset`s or `Operator`s carrying their
        # `Dataset` on a `.dataset` attribute. Normalize both cases to
        # (name, version) pairs for the lookup below.
        dependencies = (
            session.query(Model)
            .filter(
                tuple_(Model.name, Model.version).in_(
                    [
                        (dataset.name, dataset.version)
                        for dependency in self.dependencies
                        if isinstance(dependency, Dataset)
                        or hasattr(dependency, "dataset")
                        for dataset in [
                            dependency.dataset
                            if isinstance(dependency, Operator)
                            else dependency
                        ]
                    ]
                )
            )
            .all()
        )
        dataset.dependencies = dependencies
        session.add(dataset)

    def __post_init__(self):
        self.dependencies = list(self.dependencies)

        class_sources = getattr(type(self), "sources", None)

        if not isinstance(class_sources, DatasetSources):
            logger.warning(
                f"Dataset '{type(self).__name__}' has no valid class-level"
                " 'sources' attribute."
                " Defaulting to empty DatasetSources().",
                stacklevel=2,
            )
            self.sources = DatasetSources()
        else:
            self.sources = class_sources
            if self.sources.empty():
                logger.warning(
                    f"Dataset '{type(self).__name__}' defines 'sources',"
                    " but it is empty."
                    " Please check if this is intentional.",
                    stacklevel=2,
                )

        class_targets = getattr(type(self), "targets", None)

        if not isinstance(class_targets, DatasetTargets):
            logger.warning(
                f"Dataset '{type(self).__name__}' has no valid class-level"
                " 'targets' attribute."
                " Defaulting to empty DatasetTargets().",
                stacklevel=2,
            )
            self.targets = DatasetTargets()
        else:
            self.targets = class_targets
            if self.targets.empty():
                logger.warning(
                    f"Dataset '{type(self).__name__}' defines 'targets',"
                    " but it is empty."
                    " Please check if this is intentional.",
                    stacklevel=2,
                )

        if not isinstance(self.tasks, Tasks_):
            self.tasks = Tasks_(self.tasks)
        if len(self.tasks.last) > 1:
            # Explicitly create a single final task, because we can't know
            # which of the multiple tasks finishes last.
            name = prefix(self)
            name = f"{name if name else f'{self.__module__}.'}{self.name}."
            update_version = PythonOperator(
                task_id=f"{name}update-version",
                # Do nothing, because updating will be added later.
                python_callable=lambda *xs, **ks: None,
            )
            self.tasks = Tasks_((self.tasks.graph, update_version))
        # Due to the `if` block above, there will now always be exactly
        # one task in `self.tasks.last`, which the next line just
        # selects.
        last = list(self.tasks.last)[0]
        for task in self.tasks.values():
            task.dataset = self
            cls = task.__class__
            versioned = type(
                f"{self.name[0].upper()}{self.name[1:]} (versioned)",
                (cls,),
                {
                    "execute": self.check_version(
                        after_execution=[self.update] if task is last else []
                    )
                },
            )
            task.__class__ = versioned

        predecessors = [
            task
            for dataset in self.dependencies
            if isinstance(dataset, Dataset)
            for task in dataset.tasks.last
        ] + [task for task in self.dependencies if isinstance(task, Operator)]
        for p in predecessors:
            for first in self.tasks.first:
                p.set_downstream(first)

        self.register()

    def __init_subclass__(cls) -> None:
        # Warn about missing or invalid class-level attributes as early as
        # possible, i.e. at subclass definition time.
        if not isinstance(getattr(cls, "sources", None), DatasetSources):
            logger.warning(
                f"Dataset '{cls.__name__}' does not define a valid"
                " class-level 'sources'.",
                stacklevel=2,
            )
        if not isinstance(getattr(cls, "targets", None), DatasetTargets):
            logger.warning(
                f"Dataset '{cls.__name__}' does not define a valid"
                " class-level 'targets'.",
                stacklevel=2,
            )

    def register(self):
        with db.session_scope() as session:
            existing = (
                session.query(Model).filter_by(name=self.name).first()
            )

            if not existing:
                entry = Model(
                    name=self.name,
                    version="will be filled after execution",
                    scenarios="{}",
                    sources=self.sources.to_dict(),
                    targets=self.targets.to_dict(),
                )
                session.add(entry)
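
# Illustrative sketch (hypothetical names throughout): a concrete dataset
# is declared by subclassing `Dataset` with class-level `sources`/`targets`
# and a task graph, e.g.
#
#     class RoadNetwork(Dataset):
#         sources = DatasetSources(
#             urls={"roads": "https://example.org/roads.zip"}
#         )
#         targets = DatasetTargets(tables={"roads": "demand.road_network"})
#
#         def __init__(self, dependencies=()):
#             super().__init__(
#                 name="RoadNetwork",
#                 version="0.0.1",
#                 dependencies=dependencies,
#                 tasks=(download_roads, import_roads),
#             )
#
# where `download_roads` and `import_roads` are task callables defined
# elsewhere.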


def load_sources_and_targets(
    name: str,
) -> tuple[DatasetSources, DatasetTargets]:
    """
    Load DatasetSources and DatasetTargets from the datasets table.

    Parameters
    ----------
    name : str
        Name of the dataset.

    Returns
    -------
    Tuple[DatasetSources, DatasetTargets]
    """
    with db.session_scope() as session:
        dataset_entry = session.query(Model).filter_by(name=name).first()

        if dataset_entry is None:
            raise ValueError(f"Dataset '{name}' not found in the database.")

        # Extract the raw JSON dicts while the session is still open.
        raw_sources = dict(dataset_entry.sources or {})
        raw_targets = dict(dataset_entry.targets or {})

    # Recreate the objects outside the session, which is safe now that the
    # raw data has been copied.
    sources = DatasetSources(**raw_sources)
    targets = DatasetTargets(**raw_targets)

    return sources, targets
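
# Illustrative usage (hypothetical dataset name):
#
#     sources, targets = load_sources_and_targets("RoadNetwork")
#     sources.tables  # mapping of keys to "schema.table" strings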


def export_dataset_io_to_json(
    output_path: str = "dataset_io_overview.json",
) -> None:
    """
    Export all sources and targets of datasets to a JSON file.

    Parameters
    ----------
    output_path : str
        Path to the output JSON file.
    """
    result = {}

    with db.session_scope() as session:
        datasets = session.query(Model).all()

        for dataset in datasets:
            name = dataset.name

            try:
                raw_sources = dict(dataset.sources or {})
                raw_targets = dict(dataset.targets or {})

                result[name] = {
                    "sources": raw_sources,
                    "targets": raw_targets,
                }
            except Exception as e:
                print(f"⚠️ Could not process dataset '{name}': {e}")

    # Save to JSON.
    output_file = Path(output_path)
    output_file.write_text(json.dumps(result, indent=2, ensure_ascii=False))
    print(f"✅ Dataset I/O overview written to {output_file.resolve()}")
527