Passed
Pull Request — main (#165)
by Chaitanya
01:47
created

asgardpy.base.base.PathType.validate()   A

Complexity

Conditions 4

Size

Total Lines 14
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 10
nop 2
dl 0
loc 14
rs 9.9
c 0
b 0
f 0
1
"""
2
Classes containing the Base for the Analysis steps and some Basic Config types.
3
"""
4
5
from dataclasses import dataclass
6
from enum import Enum
7
from pathlib import Path
8
from typing import Annotated
9
10
from astropy import units as u
11
from astropy.time import Time
12
from pydantic import (
13
    AfterValidator,
14
    BaseModel,
15
    ConfigDict,
16
    GetCoreSchemaHandler,
17
    PlainSerializer,
18
    WithJsonSchema,
19
)
20
from pydantic_core import core_schema
21
22
# Public names of this module (what ``from <module> import *`` exports).
__all__ = [
    "AngleType",
    "BaseConfig",
    "EnergyRangeConfig",
    "EnergyType",
    "FrameEnum",
    "PathType",
    "TimeFormatEnum",
    "TimeInterval",
]
32
33
34
# Basic Quantities Type for building the Config
35
def validate_angle_type(v: str | u.Quantity) -> u.Quantity:
    """
    Validation for Base Angle Type Quantity.

    Parameters
    ----------
    v : str or astropy.units.Quantity
        A Quantity, or a string that ``u.Quantity`` can parse.

    Returns
    -------
    astropy.units.Quantity
        The value as a Quantity (returned unchanged if it already was one).

    Raises
    ------
    ValueError
        If ``v`` is neither a str nor a Quantity, cannot be parsed as a
        Quantity, or does not carry an angle unit. ValueError is used so that
        pydantic reports the failure as a regular validation error.
    """
    if isinstance(v, u.Quantity):
        v_ = v
    elif isinstance(v, str):
        try:
            v_ = u.Quantity(v)
        except TypeError as err:
            # astropy signals unparseable strings with TypeError, which pydantic
            # would not convert into a ValidationError.
            raise ValueError(str(err)) from err
    else:
        # Previously this fell through with ``v_`` unbound (UnboundLocalError).
        raise ValueError(f"Invalid type for angle: {type(v).__name__!r}, expected str or Quantity")

    if v_.unit.physical_type != "angle":
        raise ValueError(f"Invalid unit for angle: {v_.unit!r}")
    return v_
45
46
47
# Base Angle Type Quantity.
# Accepts a str or Quantity; after type validation, validate_angle_type converts
# it to a Quantity and requires an angle unit. Serialized via u.Quantity(x).
AngleType = Annotated[
    str | u.Quantity,
    AfterValidator(validate_angle_type),
    PlainSerializer(lambda x: u.Quantity(x), return_type=u.Quantity),
    # In serialization-mode JSON schemas the value is advertised as a string.
    WithJsonSchema({"type": "string"}, mode="serialization"),
]
54
55
56
def validate_energy_type(v: str | u.Quantity) -> u.Quantity:
    """
    Validation for Base Energy Type Quantity.

    Parameters
    ----------
    v : str or astropy.units.Quantity
        A Quantity, or a string that ``u.Quantity`` can parse.

    Returns
    -------
    astropy.units.Quantity
        The value as a Quantity (returned unchanged if it already was one).

    Raises
    ------
    ValueError
        If ``v`` is neither a str nor a Quantity, cannot be parsed as a
        Quantity, or does not carry an energy unit. ValueError is used so that
        pydantic reports the failure as a regular validation error.
    """
    if isinstance(v, u.Quantity):
        v_ = v
    elif isinstance(v, str):
        try:
            v_ = u.Quantity(v)
        except TypeError as err:
            # astropy signals unparseable strings with TypeError, which pydantic
            # would not convert into a ValidationError.
            raise ValueError(str(err)) from err
    else:
        # Previously this fell through with ``v_`` unbound (UnboundLocalError).
        raise ValueError(f"Invalid type for energy: {type(v).__name__!r}, expected str or Quantity")

    if v_.unit.physical_type != "energy":
        raise ValueError(f"Invalid unit for energy: {v_.unit!r}")
    return v_
66
67
68
# Base Energy Type Quantity.
# Accepts a str or Quantity; after type validation, validate_energy_type converts
# it to a Quantity and requires an energy unit. Serialized via u.Quantity(x).
EnergyType = Annotated[
    str | u.Quantity,
    AfterValidator(validate_energy_type),
    PlainSerializer(lambda x: u.Quantity(x), return_type=u.Quantity),
    # In serialization-mode JSON schemas the value is advertised as a string.
    WithJsonSchema({"type": "string"}, mode="serialization"),
]
75
76
77
def validate_path_type(v: str) -> Path:
    """
    Validate a path-like config value.

    The literal string ``"None"`` is mapped to ``Path(".")`` without any check.
    Any other value is resolved to an absolute path for the existence check,
    but the returned Path is built from the original, unresolved ``v``.

    Raises
    ------
    ValueError
        If the resolved location does not exist.
    """
    if v == "None":
        return Path(".")

    resolved = Path(v).resolve()
    # For an existing file, the location checked is its containing directory.
    location = resolved.parent if resolved.is_file() else resolved
    if not location.exists():
        raise ValueError(f"Path {v} does not exist")
    return Path(v)
91
92
93
# Base Path Type.
# Accepts a str or Path; after type validation, validate_path_type checks that
# the location exists (the string "None" maps to Path(".")). Serialized as Path(x).
PathType = Annotated[
    str | Path,
    AfterValidator(validate_path_type),
    PlainSerializer(lambda x: Path(x), return_type=Path),
    # In serialization-mode JSON schemas the value is advertised as a string.
    WithJsonSchema({"type": "string"}, mode="serialization"),
]
99
100
101
class FrameEnum(str, Enum):
    """
    Config section listing the frames allowed when creating a SkyCoord object.

    As a ``str`` subclass, each member compares equal to its string value.
    """

    icrs = "icrs"
    galactic = "galactic"
106
107
108
class TimeFormatEnum(str, Enum):
    """
    Config section listing the formats allowed when creating a Time object.

    As a ``str`` subclass, each member compares equal to its string value.
    The values are presumably intended to match astropy ``Time`` format
    names (``Time(..., format=...)``) — NOTE(review): confirm.
    """

    datetime = "datetime"
    fits = "fits"
    iso = "iso"
    isot = "isot"
    jd = "jd"
    mjd = "mjd"
    unix = "unix"
118
119
120
@dataclass
class TimeInterval:
    """
    Config section for getting main information for creating a Time Interval
    object.

    When validated through pydantic, the input is a dict of strings with
    ``"format"``, ``"start"`` and ``"stop"`` keys. After validation,
    ``interval["start"]`` and ``interval["stop"]`` hold astropy ``Time``
    objects. ``build`` turns the interval back into a plain str-valued dict.
    """

    # NOTE(review): after ``_validate`` the start/stop values are ``Time``
    # objects, not ``str``/``float`` as this annotation suggests.
    interval: dict[str, str | float]

    def build(self) -> dict:
        """
        Return a serializable representation of the interval.

        Returns
        -------
        dict
            ``{"format": ..., "start": ..., "stop": ...}``. ``format`` is
            taken from ``Time(self.interval["start"])``, and both bounds are
            converted with ``str``.
        """
        return {
            "format": Time(self.interval["start"]).format,
            "start": str(self.interval["start"]),
            "stop": str(self.interval["stop"]),
        }

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: type[dict], handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Pydantic hook for validation and serialization.

        Input is first validated as ``dict[str, str]``, then passed to
        ``_validate``. Serialization goes through ``_serialize``, producing a
        ``dict[str, str]``.
        """
        assert source is TimeInterval
        return core_schema.no_info_after_validator_function(
            cls._validate,
            core_schema.dict_schema(keys_schema=core_schema.str_schema(), values_schema=core_schema.str_schema()),
            serialization=core_schema.plain_serializer_function_ser_schema(
                cls._serialize,
                info_arg=False,
                return_schema=core_schema.dict_schema(
                    keys_schema=core_schema.str_schema(), values_schema=core_schema.str_schema()
                ),
            ),
        )

    @staticmethod
    def _validate(value: dict) -> "TimeInterval":
        """
        Build a TimeInterval from a dict with ``format``, ``start`` and ``stop``.

        Raises
        ------
        ValueError
            If a required key is missing, or if ``start``/``stop`` cannot be
            parsed as a ``Time`` with the given format. Only ValueError is
            raised, so pydantic reports these as validation errors.
        """
        missing = [key for key in ("format", "start", "stop") if key not in value]
        if missing:
            raise ValueError(f"Time interval is missing required key(s): {', '.join(missing)}")

        time_format = value["format"]
        inv_dict: dict[str, Time | TimeFormatEnum] = {}

        # Via the core schema the format arrives as a plain str, so this is
        # only true when _validate is called directly with an enum member.
        if isinstance(time_format, TimeFormatEnum):
            inv_dict["format"] = time_format

        for key in ("start", "stop"):
            # Read all values as string, without mutating the caller's dict.
            time_str = str(value[key])
            # Time raises on bad input (it never returns a falsy object), so
            # invalid values are detected by catching the exception.
            try:
                inv_dict[key] = Time(time_str, format=time_format)
            except (ValueError, TypeError) as err:
                raise ValueError(f"{time_str} is not the right Time value for format {time_format}") from err

        return TimeInterval(inv_dict)

    @staticmethod
    def _serialize(value: "TimeInterval") -> dict:
        """Serialize a TimeInterval via its ``build`` representation."""
        return value.build()
184
185
186
class BaseConfig(BaseModel):
    """
    Base Config class for creating other Config sections with specific encoders.

    Subclasses share the ``model_config`` settings below.
    """

    model_config = ConfigDict(
        # Permit non-pydantic field types such as astropy Quantity.
        arbitrary_types_allowed=True,
        # Re-validate fields when they are assigned after construction.
        validate_assignment=True,
        # Reject unknown keys instead of ignoring them.
        extra="forbid",
        # Run validators on default values as well.
        validate_default=True,
        # Store enum fields as their underlying values rather than members.
        use_enum_values=True,
        # Encode Quantity as "<value> <unit>" in JSON output.
        # NOTE(review): ``json_encoders`` is deprecated in pydantic v2.
        json_encoders={u.Quantity: lambda v: f"{v.value} {v.unit}"},
    )
199
200
201
# Basic Quantity ranges Type for building the Config
class EnergyRangeConfig(BaseConfig):
    """
    Config section for getting an energy range information for creating an
    Energy type Quantity object.

    Both bounds are EnergyType values, so each must carry an energy unit.
    """

    # Lower bound of the range (default 1 GeV).
    # NOTE(review): min <= max is not checked in this class.
    min: EnergyType = 1 * u.GeV
    # Upper bound of the range (default 1 TeV).
    max: EnergyType = 1 * u.TeV
210