Completed
Push — main ( f62cdb...aa7cf2 )
by Chaitanya
23s queued 15s
created

asgardpy.base.base.TimeInterval.build()   A

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 1
dl 0
loc 9
rs 10
c 0
b 0
f 0
1
"""
2
Classes containing the Base for the Analysis steps and some Basic Config types.
3
"""
4
5
import html
6
from dataclasses import dataclass
7
from enum import Enum
8
from pathlib import Path
9
from typing import Annotated
10
11
from astropy import units as u
12
from astropy.time import Time
13
from pydantic import (
14
    BaseModel,
15
    BeforeValidator,
16
    ConfigDict,
17
    GetCoreSchemaHandler,
18
    PlainSerializer,
19
)
20
from pydantic_core import core_schema
21
22
# Public names exported by ``from <module> import *``: the annotated field
# types, the enums, TimeInterval and the config base classes defined below.
__all__ = [
    "AngleType",
    "BaseConfig",
    "EnergyRangeConfig",
    "EnergyType",
    "FrameEnum",
    "PathType",
    "TimeFormatEnum",
    "TimeInterval",
]
32
33
34
# Base Angle Type Quantity
35
def validate_angle_type(v: str | u.Quantity) -> u.Quantity:
    """
    Validation for Base Angle Type Quantity.

    Parameters
    ----------
    v : str or astropy.units.Quantity
        A Quantity, or a string that ``astropy.units.Quantity`` can parse.

    Returns
    -------
    astropy.units.Quantity
        ``v`` itself when it already is a Quantity, otherwise the parsed value.

    Raises
    ------
    ValueError
        If ``v`` is neither a str nor a Quantity, cannot be parsed, or its
        unit is not of physical type "angle".
    """
    if isinstance(v, u.Quantity):
        v_ = v
    elif isinstance(v, str):
        try:
            v_ = u.Quantity(v)
        except (TypeError, ValueError) as err:
            # astropy may raise TypeError for unparsable strings; pydantic only
            # converts ValueError (not TypeError) into a ValidationError.
            raise ValueError(f"Could not parse {v!r} as an angle Quantity") from err
    else:
        # Previously this fell through with ``v_`` unbound (UnboundLocalError).
        raise ValueError(f"Expected a str or Quantity for angle, got {type(v).__name__}")

    if v_.unit.physical_type != "angle":
        raise ValueError(f"Invalid unit for angle: {v_.unit!r}")
    return v_
45
46
47
# Angle-valued field: accepts a str or a Quantity, is checked by
# ``validate_angle_type`` and is written to JSON as "<value> <unit>".
AngleType = Annotated[
    str | u.Quantity,
    BeforeValidator(validate_angle_type),
    PlainSerializer(
        lambda quantity: f"{quantity.value} {quantity.unit}",
        when_used="json-unless-none",
        return_type=str,
    ),
]
52
53
54
# Base Energy Type Quantity
55
def validate_energy_type(v: str | u.Quantity) -> u.Quantity:
    """
    Validation for Base Energy Type Quantity.

    Parameters
    ----------
    v : str or astropy.units.Quantity
        A Quantity, or a string that ``astropy.units.Quantity`` can parse.

    Returns
    -------
    astropy.units.Quantity
        ``v`` itself when it already is a Quantity, otherwise the parsed value.

    Raises
    ------
    ValueError
        If ``v`` is neither a str nor a Quantity, cannot be parsed, or its
        unit is not of physical type "energy".
    """
    if isinstance(v, u.Quantity):
        v_ = v
    elif isinstance(v, str):
        try:
            v_ = u.Quantity(v)
        except (TypeError, ValueError) as err:
            # astropy may raise TypeError for unparsable strings; pydantic only
            # converts ValueError (not TypeError) into a ValidationError.
            raise ValueError(f"Could not parse {v!r} as an energy Quantity") from err
    else:
        # Previously this fell through with ``v_`` unbound (UnboundLocalError).
        raise ValueError(f"Expected a str or Quantity for energy, got {type(v).__name__}")

    if v_.unit.physical_type != "energy":
        raise ValueError(f"Invalid unit for energy: {v_.unit!r}")
    return v_
65
66
67
# Energy-valued field: accepts a str or a Quantity, is checked by
# ``validate_energy_type`` and is written to JSON as "<value> <unit>".
EnergyType = Annotated[
    str | u.Quantity,
    BeforeValidator(validate_energy_type),
    PlainSerializer(
        lambda quantity: f"{quantity.value} {quantity.unit}",
        when_used="json-unless-none",
        return_type=str,
    ),
]
72
73
74
# Base Path Type
75
def validate_path_type(v: str) -> Path:
    """
    Check that a path config value points to an existing location.

    The literal string ``"None"`` maps to ``Path(".")``. For any other value
    the resolved location must exist. For an existing file, its parent
    directory is checked instead. The value is returned as an unresolved
    ``Path``.

    Raises
    ------
    ValueError
        If the location does not exist.
    """
    if v == "None":
        return Path(".")

    location = Path(v).resolve()
    # Only check if the file location or directory path exists
    target = location.parent if location.is_file() else location
    if not target.exists():
        raise ValueError(f"Path {v} does not exist")
    return Path(v)
89
90
91
# Path-valued field: checked by ``validate_path_type`` and emitted as a
# ``Path`` when serialised to JSON.
PathType = Annotated[
    str | Path,
    BeforeValidator(validate_path_type),
    PlainSerializer(
        lambda raw_path: Path(raw_path),
        when_used="json-unless-none",
        return_type=Path,
    ),
]
96
97
98
class FrameEnum(str, Enum):
    """
    Allowed frame names when creating a SkyCoord object.

    Being a ``str`` mixin, each member also compares equal to its plain
    string value.
    """

    # Declaration order defines iteration order.
    icrs = "icrs"
    galactic = "galactic"
103
104
105
class TimeFormatEnum(str, Enum):
    """
    Allowed format names when creating a Time object.

    Being a ``str`` mixin, each member also compares equal to its plain
    string value.
    """

    # Declaration order defines iteration order.
    datetime = "datetime"
    fits = "fits"
    iso = "iso"
    isot = "isot"
    jd = "jd"
    mjd = "mjd"
    unix = "unix"
115
116
117
@dataclass
class TimeInterval:
    """
    Config section for getting main information for creating a Time Interval
    object.

    When produced by pydantic validation (``_validate``), ``interval`` holds
    astropy ``Time`` objects under ``"start"`` and ``"stop"``. It also holds
    the matching ``TimeFormatEnum`` member under ``"format"`` when the
    requested format is one of the enum values. ``build`` turns it back into
    a dict of strings for serialization.
    """

    interval: dict[str, str | float | Time]

    def build(self) -> dict:
        """
        Return the interval as a plain dict of strings.

        Returns
        -------
        dict
            ``{"format": ..., "start": ..., "stop": ...}``. ``format`` is the
            format of the ``Time`` built from ``interval["start"]``, and
            ``start``/``stop`` are the ``str`` forms of the stored values.
        """
        start = self.interval["start"]
        return {
            "format": Time(start).format,
            "start": str(start),
            "stop": str(self.interval["stop"]),
        }

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type[dict], handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Describe to pydantic how to validate and serialize a TimeInterval.

        Input is first validated as ``dict[str, str]`` and then handed to
        ``_validate``. Output is produced by ``_serialize`` and described as a
        ``dict[str, str]``.
        """
        assert source is TimeInterval
        return core_schema.no_info_after_validator_function(
            cls._validate,
            core_schema.dict_schema(keys_schema=core_schema.str_schema(), values_schema=core_schema.str_schema()),
            serialization=core_schema.plain_serializer_function_ser_schema(
                cls._serialize,
                info_arg=False,
                return_schema=core_schema.dict_schema(
                    keys_schema=core_schema.str_schema(), values_schema=core_schema.str_schema()
                ),
            ),
        )

    @staticmethod
    def _validate(value: dict) -> "TimeInterval":
        """
        Build a TimeInterval from a ``dict[str, str]`` validated by pydantic.

        Parameters
        ----------
        value : dict
            Must contain the keys ``"format"``, ``"start"`` and ``"stop"``.
            It is not modified.

        Returns
        -------
        TimeInterval
            With ``Time`` objects for ``start``/``stop``. ``format`` is added
            when it is a ``TimeFormatEnum`` value.

        Raises
        ------
        ValueError
            If a required key is missing, or ``start``/``stop`` cannot be
            turned into an astropy ``Time`` with the given format.
        """
        missing = [key for key in ("format", "start", "stop") if key not in value]
        if missing:
            raise ValueError(f"Time interval is missing required key(s): {', '.join(missing)}")

        time_format = value["format"]
        inv_dict: dict[str, str | float | Time] = {}

        # The dict_schema validates every value with str_schema, so "format"
        # may arrive as a plain str rather than a TimeFormatEnum member (the
        # old isinstance check never matched). TimeFormatEnum(...) accepts
        # either form. Formats outside the enum are not stored here; astropy
        # accepts or rejects them below.
        try:
            inv_dict["format"] = TimeFormatEnum(time_format)
        except ValueError:
            pass

        for key in ("start", "stop"):
            # Read all values as string
            raw_value = str(value[key])
            # A bad value or format makes the Time constructor itself raise;
            # wrap it in a ValueError so pydantic reports a ValidationError.
            try:
                inv_dict[key] = Time(raw_value, format=time_format)
            except (TypeError, ValueError) as err:
                raise ValueError(
                    f"{raw_value} is not the right Time value for format {time_format}"
                ) from err

        return TimeInterval(inv_dict)

    @staticmethod
    def _serialize(value: "TimeInterval") -> dict:
        """Serialize a TimeInterval through its ``build`` method."""
        return value.build()
179
180
181
class BaseConfig(BaseModel):
    """
    Base Config class for creating other Config sections with specific encoders.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        validate_assignment=True,
        extra="forbid",
        validate_default=True,
        use_enum_values=True,
    )

    def _repr_html_(self):
        """
        Rich HTML representation hook (IPython ``_repr_html_`` protocol).

        Uses ``self.to_html()`` when such a method exists. Otherwise it falls
        back to the HTML-escaped ``str(self)`` wrapped in ``<pre>``.
        """
        to_html = getattr(self, "to_html", None)
        if to_html is None:
            return f"<pre>{html.escape(str(self))}</pre>"
        # Only a missing ``to_html`` selects the fallback. An AttributeError
        # raised *inside* ``to_html`` now propagates instead of being masked.
        return to_html()
199
200
201
# Basic Quantity range types for building the Config
202
class EnergyRangeConfig(BaseConfig):
    """
    Config section holding an energy range as two energy Quantities.

    Defaults to ``min`` = 1 GeV and ``max`` = 1 TeV. Both are validated as
    ``EnergyType``.
    """

    # Minimum energy of the range (field name shadows the ``min`` builtin
    # inside the class body only; kept for config compatibility).
    min: EnergyType = u.Quantity(1, u.GeV)
    # Maximum energy of the range.
    max: EnergyType = u.Quantity(1, u.TeV)
210