asgardpy.base.base   A
last analyzed

Complexity

Total Complexity 18

Size/Duplication

Total Lines 202
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 120
dl 0
loc 202
rs 10
c 0
b 0
f 0
wmc 18

3 Functions

Rating   Name   Duplication   Size   Complexity  
A validate_angle_type() 0 10 4
A validate_energy_type() 0 10 4
A validate_path_type() 0 14 4

5 Methods

Rating   Name   Duplication   Size   Complexity  
A TimeInterval.build() 0 9 1
A TimeInterval.__get_pydantic_core_schema__() 0 13 1
A TimeInterval._serialize() 0 3 1
A TimeInterval._validate() 0 14 1
A BaseConfig._repr_html_() 0 5 2
1
"""
2
Classes containing the Base for the Analysis steps and some Basic Config types.
3
"""
4
5
import html
6
from dataclasses import dataclass
7
from enum import Enum
8
from pathlib import Path
9
from typing import Annotated
10
11
from astropy import units as u
12
from astropy.time import Time
13
from pydantic import (
14
    BaseModel,
15
    BeforeValidator,
16
    ConfigDict,
17
    GetCoreSchemaHandler,
18
    PlainSerializer,
19
)
20
from pydantic_core import core_schema
21
22
# Public API of this module: the names exported by ``from ... import *``.
# The ``validate_*`` helpers are intentionally not listed here.
__all__ = [
    "AngleType",
    "BaseConfig",
    "EnergyRangeConfig",
    "EnergyType",
    "FrameEnum",
    "PathType",
    "TimeFormatEnum",
    "TimeInterval",
]
32
33
34
# Base Angle Type Quantity
35
def validate_angle_type(v: str | u.Quantity) -> u.Quantity:
    """
    Validation for Base Angle Type Quantity.

    Parameters
    ----------
    v : str or `astropy.units.Quantity`
        Either a Quantity, or a string parsed with ``astropy.units.Quantity``
        (e.g. ``"1 deg"``).

    Returns
    -------
    `astropy.units.Quantity`
        The Quantity, whose unit has the ``"angle"`` physical type.

    Raises
    ------
    ValueError
        If ``v`` is neither a str nor a Quantity, or if its unit is not an
        angle.
    """
    if isinstance(v, u.Quantity):
        v_ = v
    elif isinstance(v, str):
        v_ = u.Quantity(v)
    else:
        # Reject other inputs (None, bare numbers, ...) explicitly instead of
        # falling through with ``v_`` unbound; a ValueError is what pydantic
        # turns into a regular ValidationError.
        raise ValueError(f"Invalid type for angle: {type(v).__name__!r}, expected str or Quantity")

    if v_.unit.physical_type != "angle":
        raise ValueError(f"Invalid unit for angle: {v_.unit!r}")
    return v_
45
46
47
# Angle-valued config field: raw input is first run through
# validate_angle_type; in JSON mode (skipped for None) the Quantity is
# written out as "<value> <unit>".
AngleType = Annotated[
    str | u.Quantity,
    BeforeValidator(validate_angle_type),
    PlainSerializer(
        lambda quantity: f"{quantity.value} {quantity.unit}",
        return_type=str,
        when_used="json-unless-none",
    ),
]
52
53
54
# Base Energy Type Quantity
55
def validate_energy_type(v: str | u.Quantity) -> u.Quantity:
    """
    Validation for Base Energy Type Quantity.

    Parameters
    ----------
    v : str or `astropy.units.Quantity`
        Either a Quantity, or a string parsed with ``astropy.units.Quantity``
        (e.g. ``"1 TeV"``).

    Returns
    -------
    `astropy.units.Quantity`
        The Quantity, whose unit has the ``"energy"`` physical type.

    Raises
    ------
    ValueError
        If ``v`` is neither a str nor a Quantity, or if its unit is not an
        energy.
    """
    if isinstance(v, u.Quantity):
        v_ = v
    elif isinstance(v, str):
        v_ = u.Quantity(v)
    else:
        # Reject other inputs (None, bare numbers, ...) explicitly instead of
        # falling through with ``v_`` unbound; a ValueError is what pydantic
        # turns into a regular ValidationError.
        raise ValueError(f"Invalid type for energy: {type(v).__name__!r}, expected str or Quantity")

    if v_.unit.physical_type != "energy":
        raise ValueError(f"Invalid unit for energy: {v_.unit!r}")
    return v_
65
66
67
# Energy-valued config field: raw input is first run through
# validate_energy_type; in JSON mode (skipped for None) the Quantity is
# written out as "<value> <unit>".
EnergyType = Annotated[
    str | u.Quantity,
    BeforeValidator(validate_energy_type),
    PlainSerializer(
        lambda quantity: f"{quantity.value} {quantity.unit}",
        return_type=str,
        when_used="json-unless-none",
    ),
]
72
73
74
# Base Path Type Quantity
75
def validate_path_type(v: str) -> Path:
76
    """Validation for Base Path Type Quantity"""
77
    if v == "None":
78
        return Path(".")
79
    else:
80
        path_ = Path(v).resolve()
81
        # Only check if the file location or directory path exists
82
        if path_.is_file():
83
            path_ = path_.parent
84
85
        if path_.exists():
86
            return Path(v)
87
        else:
88
            raise ValueError(f"Path {v} does not exist")
89
90
91
# Path-valued config field: raw input is first run through
# validate_path_type; in JSON mode (skipped for None) the value is passed
# through Path() on output.
PathType = Annotated[
    str | Path,
    BeforeValidator(validate_path_type),
    PlainSerializer(
        lambda path_value: Path(path_value),
        return_type=Path,
        when_used="json-unless-none",
    ),
]
96
97
98
class FrameEnum(str, Enum):
    """
    Config section for list of frames on creating a SkyCoord object.

    Members are ``str`` subclasses, so each compares equal to its plain string
    value (e.g. ``FrameEnum.icrs == "icrs"``).
    """

    icrs = "icrs"
    galactic = "galactic"
103
104
105
class TimeFormatEnum(str, Enum):
    """
    Config section for list of formats for creating a Time object.

    Members are ``str`` subclasses, so each compares equal to its plain string
    value (e.g. ``TimeFormatEnum.mjd == "mjd"``).
    """

    datetime = "datetime"
    fits = "fits"
    iso = "iso"
    isot = "isot"
    jd = "jd"
    mjd = "mjd"
    unix = "unix"
115
116
117
@dataclass
class TimeInterval:
    """
    Config section for getting main information for creating a Time Interval
    object.

    ``interval`` holds a ``"format"`` entry (a Time format name) together with
    ``"start"`` and ``"stop"`` entries. When built through pydantic validation
    (see ``_validate``), ``start`` and ``stop`` are stored as
    `astropy.time.Time` objects and ``format`` is kept as given.
    """

    interval: dict[str, str | float | Time]

    def build(self) -> dict:
        """
        Return the interval as a plain ``{"format", "start", "stop"}`` dict.

        ``format`` is read from ``Time(start)``, and ``start``/``stop`` are
        converted with ``str``. This is the form produced on serialization.
        """
        value_dict = {}
        value_dict["format"] = Time(self.interval["start"]).format
        value_dict["start"] = str(self.interval["start"])
        value_dict["stop"] = str(self.interval["stop"])
        return value_dict

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Build the pydantic-core schema used for TimeInterval fields.

        Input must be a dict with str keys whose values are strings or plain
        numbers; it is converted with ``_validate``. Serialization goes
        through ``_serialize`` and yields a dict of strings.
        """
        assert source is TimeInterval
        # pydantic v2 does not coerce numbers to str, so numeric times (e.g. an
        # MJD written as a bare number) are accepted explicitly here and then
        # stringified in ``_validate``. The strict int/float schemas keep bools
        # from being coerced into numbers.
        input_value_schema = core_schema.union_schema(
            [
                core_schema.str_schema(),
                core_schema.int_schema(strict=True),
                core_schema.float_schema(strict=True),
            ]
        )
        return core_schema.no_info_after_validator_function(
            cls._validate,
            core_schema.dict_schema(keys_schema=core_schema.str_schema(), values_schema=input_value_schema),
            serialization=core_schema.plain_serializer_function_ser_schema(
                cls._serialize,
                info_arg=False,
                return_schema=core_schema.dict_schema(
                    keys_schema=core_schema.str_schema(), values_schema=core_schema.str_schema()
                ),
            ),
        )

    @staticmethod
    def _validate(value: dict) -> "TimeInterval":
        """
        Convert a raw ``{"format", "start", "stop"}`` dict into a TimeInterval.

        ``start`` and ``stop`` are read as strings and parsed into ``Time``
        objects with the given ``format``. The input dict is left unmodified.

        Raises
        ------
        ValueError
            If any of the ``format``, ``start`` or ``stop`` keys is missing.
        """
        # A ValueError, unlike a bare KeyError, is reported by pydantic as a
        # ValidationError.
        missing_keys = [key for key in ("format", "start", "stop") if key not in value]
        if missing_keys:
            raise ValueError(f"Time interval is missing the key(s): {', '.join(missing_keys)}")

        time_format = value["format"]
        # Read the start/stop values as strings, without writing them back
        # into the caller's dict.
        start = str(value["start"])
        stop = str(value["stop"])

        inv_dict: dict[str, str | Time] = {
            "format": time_format,
            "start": Time(start, format=time_format),
            "stop": Time(stop, format=time_format),
        }
        return TimeInterval(inv_dict)

    @staticmethod
    def _serialize(value: "TimeInterval") -> dict:
        """Serialize a TimeInterval into its dict-of-strings form via ``build``."""
        return value.build()
171
172
173
class BaseConfig(BaseModel):
    """
    Common pydantic base model that the other Config sections derive from.

    Subclasses share one ``model_config``:
    - arbitrary (non-pydantic) field types are allowed;
    - assignments and defaults are validated;
    - unknown fields are rejected;
    - enum fields store their plain values.
    """

    model_config = ConfigDict(
        extra="forbid",
        use_enum_values=True,
        validate_assignment=True,
        validate_default=True,
        arbitrary_types_allowed=True,
    )

    def _repr_html_(self):  # pragma: no cover
        # IPython rich-display hook: use a ``to_html`` method when available,
        # otherwise show the HTML-escaped ``str`` of the model in a <pre> block.
        try:
            rendered = self.to_html()
        except AttributeError:
            rendered = f"<pre>{html.escape(str(self))}</pre>"
        return rendered
191
192
193
# Basic Quantity ranges Type for building the Config
194
class EnergyRangeConfig(BaseConfig):
    """
    Config section describing an energy range by its ``min`` and ``max``
    bounds, each an energy-type Quantity object.

    Defaults span 1 GeV to 1 TeV.
    """

    min: EnergyType = u.Quantity(1, u.GeV)
    max: EnergyType = u.Quantity(1, u.TeV)
202