Passed
Pull Request — dev (#1375)
created 02:15 by unknown

data.validation_utils   A

Complexity
Total Complexity 20

Size/Duplication
Total Lines 206
Duplicated Lines 0%

Importance
Changes 0

Metric  Value
wmc     20
eloc    84
dl      0
loc     206
rs      10
c       0
b       0
f       0

3 Functions

Rating  Name                       Duplication  Size  Complexity
A       _resolve_rule_params()     0            25    5
A       _resolve_context_value()   0            44    5
C       create_validation_tasks()  0            122   10
"""Airflow integration for egon-validation."""

import logging
from typing import Any, Dict, List

from airflow.operators.python import PythonOperator

from egon_validation import run_validations, RunContext
from egon_validation.rules.base import Rule

logger = logging.getLogger(__name__)


def _resolve_context_value(value: Any, boundary: str, scenarios: List[str]) -> Any:
    """Resolve a value that may be context-dependent (boundary/scenario).

    Args:
        value: The value to resolve. Can be:
            - A dict with boundary keys: {"Schleswig-Holstein": 27, "Everything": 537}
            - A dict with scenario keys: {"eGon2035": 100, "eGon100RE": 200}
            - Any other value (returned as-is)
        boundary: Current dataset boundary setting
        scenarios: List of active scenarios

    Returns:
        Resolved value based on current context

    Examples:
        >>> _resolve_context_value({"Schleswig-Holstein": 27, "Everything": 537},
        ...                        "Schleswig-Holstein", ["eGon2035"])
        27

        >>> _resolve_context_value({"eGon2035": 100, "eGon100RE": 200},
        ...                        "Everything", ["eGon2035"])
        100

        >>> _resolve_context_value(42, "Everything", ["eGon2035"])
        42
    """
    # If not a dict, return as-is
    if not isinstance(value, dict):
        return value

    # Try to resolve by boundary
    if boundary in value:
        logger.debug(f"Resolved boundary-dependent value: {boundary} -> {value[boundary]}")
        return value[boundary]

    # Try to resolve by scenario
    for scenario in scenarios:
        if scenario in value:
            logger.debug(f"Resolved scenario-dependent value: {scenario} -> {value[scenario]}")
            return value[scenario]

    # If dict doesn't match boundary/scenario pattern, return as-is
    # This handles cases like column_types dicts which are not context-dependent
    return value


def _resolve_rule_params(rule: Rule, boundary: str, scenarios: List[str]) -> None:
    """Resolve context-dependent parameters in a rule.

    Modifies rule.params in-place, resolving any top-level dict values that
    match boundary or scenario patterns.

    Args:
        rule: The validation rule to process
        boundary: Current dataset boundary setting
        scenarios: List of active scenarios
    """
    if not hasattr(rule, 'params') or not isinstance(rule.params, dict):
        return

    # Resolve each top-level parameter value
    for param_name, param_value in rule.params.items():
        resolved_value = _resolve_context_value(param_value, boundary, scenarios)

        # If the value was resolved (changed), update it
        if resolved_value is not param_value:
            logger.info(
                f"Rule {rule.rule_id}: Resolved {param_name} for "
                f"boundary='{boundary}', scenarios={scenarios}"
            )
            rule.params[param_name] = resolved_value


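# Illustration (not part of the module): for a rule whose params were declared
# per-boundary, resolution picks the value for the active boundary, e.g.:
#
#     rule.params = {"expected_count": {"Schleswig-Holstein": 27, "Everything": 537}}
#     _resolve_rule_params(rule, "Schleswig-Holstein", ["eGon2035"])
#     rule.params["expected_count"]  # -> 27
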
def create_validation_tasks(
    validation_dict: Dict[str, List[Rule]],
    dataset_name: str,
    on_failure: str = "continue"
) -> List[PythonOperator]:
    """Convert validation dict to Airflow tasks.

    Automatically resolves context-dependent parameters in validation rules.
    Parameters can be specified as dicts with boundary or scenario keys:

    - Boundary-dependent: {"Schleswig-Holstein": 27, "Everything": 537}
    - Scenario-dependent: {"eGon2035": 100, "eGon100RE": 200}

    The appropriate value is selected based on the current configuration.

    Args:
        validation_dict: {"task_name": [Rule1(), Rule2()]}
        dataset_name: Name of the dataset
        on_failure: "continue" or "fail"

    Returns:
        List of PythonOperator tasks

    Example:
        >>> validation_dict = {
        ...     "data_quality": [
        ...         RowCountValidation(
        ...             table="boundaries.vg250_krs",
        ...             rule_id="TEST_ROW_COUNT",
        ...             expected_count={"Schleswig-Holstein": 27, "Everything": 537}
        ...         )
        ...     ]
        ... }
        >>> tasks = create_validation_tasks(validation_dict, "VG250")
    """
    if not validation_dict:
        return []

    tasks = []

    for task_name, rules in validation_dict.items():
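        # Factory function: binding rules/task_name as arguments gives each
        # task's callable its own values, avoiding Python's late-binding
        # closure pitfall inside the loop.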
        def make_callable(rules, task_name):
            def run_validation(**context):
                import os
                import time
                from datetime import datetime
                from egon.data import db as egon_db
                from egon.data.config import settings

                # Use same run_id as validation report for consistency
                # This allows the validation report to collect results from all validation tasks
                run_id = (
                    os.environ.get('AIRFLOW_CTX_DAG_RUN_ID') or
                    context.get('run_id') or
                    getattr(getattr(context.get('ti'), 'dag_run', None), 'run_id', None) or
                    (context.get('dag_run') and context['dag_run'].run_id) or
                    f"airflow-{dataset_name}-{task_name}-{int(time.time())}"
                )

                # Use absolute path to ensure consistent location regardless of working directory
                # Priority: EGON_VALIDATION_DIR env var > current working directory
                out_dir = os.path.join(
                    os.environ.get('EGON_VALIDATION_DIR', os.getcwd()),
                    "validation_runs"
                )

                # Include execution timestamp in task name so retries write to separate directories
                # The validation report will filter to keep only the most recent execution per task
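                # NOTE: Airflow 2.2+ exposes this as "logical_date";
                # "execution_date" is read here for backwards compatibility.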
                execution_date = context.get('execution_date') or datetime.now()
                timestamp = execution_date.strftime('%Y%m%dT%H%M%S')
                full_task_name = f"{dataset_name}.{task_name}.{timestamp}"

                logger.info(f"Validation: {full_task_name} (run_id: {run_id})")

                # Use existing engine from egon.data.db
                engine = egon_db.engine()

                # Get current configuration context
                config = settings()["egon-data"]
                boundary = config["--dataset-boundary"]
                scenarios = config.get("--scenarios", [])

                logger.info(f"Resolving validation parameters for boundary='{boundary}', scenarios={scenarios}")

                # Set task and dataset on all rules (required by Rule base class)
                # Also resolve context-dependent parameters
                for rule in rules:
                    if not hasattr(rule, 'task') or rule.task is None:
                        rule.task = task_name
                    if not hasattr(rule, 'dataset') or rule.dataset is None:
                        rule.dataset = dataset_name

                    # Automatically resolve boundary/scenario-dependent parameters
                    _resolve_rule_params(rule, boundary, scenarios)

                ctx = RunContext(run_id=run_id, source="airflow", out_dir=out_dir)
                results = run_validations(engine, ctx, rules, full_task_name)

                total = len(results)
                failed = sum(1 for r in results if not r.success)

                logger.info(f"Complete: {total - failed}/{total} passed")

                if failed > 0 and on_failure == "fail":
                    raise Exception(f"{failed}/{total} validations failed")

                return {"total": total, "passed": total - failed, "failed": failed}

            return run_validation

        func = make_callable(rules, task_name)
        func.__name__ = f"validate_{task_name}"

        operator = PythonOperator(
            task_id=f"{dataset_name}.validate.{task_name}",
            python_callable=func,
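            # provide_context is an Airflow 1.10 flag; Airflow 2 passes the
            # context automatically and treats this flag as deprecated.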
            provide_context=True,
        )

        tasks.append(operator)

    return tasks
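
For reference, a minimal sketch of how this helper could be wired into a DAG. The DAG setup and the RowCountValidation import path are illustrative assumptions (the rule class and values come from the docstring example above); only create_validation_tasks is from this module:

# Usage sketch: DAG name, schedule, and the RowCountValidation import path are
# assumptions for illustration, not part of data.validation_utils.
from datetime import datetime

from airflow import DAG

from egon_validation.rules.common import RowCountValidation  # hypothetical path
from data.validation_utils import create_validation_tasks

with DAG(
    dag_id="vg250_validation_demo",
    start_date=datetime(2024, 1, 1),
    schedule_interval=None,
) as dag:
    validation_dict = {
        "data_quality": [
            RowCountValidation(
                table="boundaries.vg250_krs",
                rule_id="TEST_ROW_COUNT",
                # Resolved at runtime against the configured dataset boundary
                expected_count={"Schleswig-Holstein": 27, "Everything": 537},
            )
        ]
    }
    # Creates one PythonOperator per key in validation_dict; operators
    # instantiated inside the DAG context manager attach to the DAG.
    tasks = create_validation_tasks(validation_dict, "VG250", on_failure="fail")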