1
|
|
|
from typing import Any, Dict, Optional, Type |
2
|
|
|
|
3
|
|
|
from abc import ABC |
4
|
|
|
from datetime import datetime |
5
|
|
|
|
6
|
|
|
from bson import ObjectId |
7
|
|
|
from bson.objectid import InvalidId |
8
|
|
|
from pydantic import BaseConfig, BaseModel |
9
|
|
|
|
10
|
|
|
|
11
|
|
|
class OID:
    """Pydantic (v1) custom field type that coerces values to ``bson.ObjectId``."""

    @classmethod
    def __get_validators__(cls):
        """Yield the validators pydantic runs for fields annotated with ``OID``."""
        yield cls.validate

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        """Describe ``OID`` fields as strings in generated JSON Schema.

        Without this hook pydantic v1 has no schema for this custom type, so
        ``Model.schema()`` (and anything built on it) fails for models that
        declare an ``OID`` field.
        """
        field_schema.update(type="string")

    @classmethod
    def validate(cls, v):
        """Convert ``v`` to an ``ObjectId``.

        ``v`` is passed through ``str()`` first, so both an existing
        ``ObjectId`` and its string form are accepted.

        Raises:
            ValueError: if the value is not a valid ObjectId; pydantic turns
                this into a validation error. The original ``InvalidId`` is
                kept as ``__cause__`` for debugging.
        """
        try:
            return ObjectId(str(v))
        except InvalidId as err:
            raise ValueError("Invalid object ID") from err
22
|
|
|
|
23
|
|
|
|
24
|
|
|
class MongoDBModel(BaseModel, ABC):
    """Base pydantic model for documents persisted in MongoDB.

    Provides an optional ``id`` field (validated by ``OID``) and helpers to
    convert between a model instance and a MongoDB document, where the
    identifier lives under the ``_id`` key.
    """

    # Document identifier; values are coerced to ``ObjectId`` by ``OID.validate``.
    id: Optional[OID]

    class Config(BaseConfig):
        # Allow populating fields by their Python name even when an alias exists.
        allow_population_by_field_name = True
        # Serializers used by pydantic's ``.json()`` for non-JSON-native types.
        json_encoders = {
            datetime: lambda dt: dt.isoformat(),
            ObjectId: str,
        }

    @classmethod
    def from_mongo(cls, data: Dict[str, Any]) -> Optional["MongoDBModel"]:
        """Build a model instance from a MongoDB document.

        The document's ``_id`` value is mapped onto the model's ``id`` field
        (overriding any ``id`` key already present). The caller's mapping is
        not modified.

        Args:
            data: The MongoDB-style document to load.

        Returns:
            An instance of ``cls``, or ``None`` when ``data`` is empty or ``None``.
        """
        if not data:
            return None

        # Copy first: popping from ``data`` directly would strip ``_id`` from
        # the caller's document as a side effect.
        fields = dict(data)
        fields["id"] = fields.pop("_id", None)
        return cls(**fields)

    def to_mongo(self, **kwargs: Any) -> Dict[str, Any]:
        """Serialize the model into a MongoDB-compatible dictionary.

        Defaults differ from ``dict()``:

        * ``exclude_unset`` defaults to ``False``, so fields still at their
          default value are stored too.
        * ``by_alias`` defaults to ``True``, so field aliases are used as keys.

        All other keyword arguments are forwarded to ``dict()``. The ``id``
        field is removed from the result, and this method does not add an
        ``_id`` key itself.
        """
        exclude_unset = kwargs.pop("exclude_unset", False)
        by_alias = kwargs.pop("by_alias", True)

        parsed = self.dict(by_alias=by_alias, exclude_unset=exclude_unset, **kwargs)

        # NOTE(review): ``id`` is dropped rather than renamed to ``_id``,
        # presumably so MongoDB keeps/assigns ``_id`` on its own — confirm.
        parsed.pop("id", None)
        return parsed

    def dict(self, **kwargs: Any) -> Dict[str, Any]:
        """Return ``BaseModel.dict()`` output with metadata fields hidden.

        ``_collection`` is always excluded. This holds even when the caller
        passes its own ``exclude``, whether as a set-like of field names, as a
        mapping (nested form), or explicitly as ``None``. The hidden fields are
        merged into the caller's exclusions rather than replacing them.
        """
        from collections.abc import Mapping

        hidden_fields = {"_collection"}
        exclude = kwargs.get("exclude")
        if exclude is None:
            exclude = hidden_fields
        elif isinstance(exclude, Mapping):
            # Mapping form: ``...`` means "exclude this key entirely".
            exclude = {**exclude, **{name: ... for name in hidden_fields}}
        else:
            exclude = set(exclude) | hidden_fields
        kwargs["exclude"] = exclude
        return super().dict(**kwargs)
73
|
|
|
|