countries.CountriesUpdater.__init__()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 21
Code Lines 16

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 16
dl 0
loc 21
rs 9.6
c 0
b 0
f 0
cc 2
nop 1
1
from io import BytesIO
2
from typing import ClassVar, Final, final
3
4
from lxml import etree
5
6
# future: report mandatory usage of protected member to lxml developers at https://bugs.launchpad.net/lxml
7
#  https://github.com/lxml/lxml/blob/a4a78214506409e5bbb6c4249cac0c0ca6479d3e/src/lxml/etree.pyx#L1877
8
#  https://github.com/lxml/lxml/blob/a4a78214506409e5bbb6c4249cac0c0ca6479d3e/src/lxml/etree.pyx#L3166
9
# noinspection PyProtectedMember
10
from lxml.etree import XMLSchema, XMLSyntaxError, _Element
11
from pandas import DataFrame
12
import pandas as pd
13
from requests import HTTPError
14
from sqlalchemy import Column, MetaData, SmallInteger, String, Table, text
15
from zipfile import ZipFile
16
17
from src.new_data_processors.common import UICTableUpdater
18
19
20
def _uic_code_not_assigned(values: tuple[str, str, str]) -> bool:
21
    return values[1] is None
22
23
24
def _swap_name(name: str) -> str:
25
    return " ".join(name.split(", ")[::-1])
26
27
28
@final
class CountriesUpdater(UICTableUpdater):
    """Download, validate, clean and store the country list in table ``countries``.

    The country data is an XML document fetched from ``DATA_URL``; it is
    validated against an XSD schema that is fetched as a .zip archive from
    ``XSD_URL``.

    NOTE(review): ``get_data``, ``DATA_BASE_URL``, ``logger``, ``database`` and
    ``data`` are not defined in this class and are presumably provided by
    ``UICTableUpdater`` -- confirm against the base class.
    """

    TABLE_NAME: ClassVar[str] = "countries"
    database_metadata: ClassVar[MetaData] = MetaData()

    table: ClassVar[Table] = Table(
        TABLE_NAME,
        database_metadata,
        Column(name="code_iso", type_=String(2), nullable=False, index=True),
        Column(name="code_uic", type_=SmallInteger, nullable=False, primary_key=True),
        Column(name="name_en", type_=String(255), nullable=False),
        Column(name="name_fr", type_=String(255)),
        Column(name="name_de", type_=String(255)),
    )

    def __init__(self) -> None:
        """Fetch the country XML and try to fetch and compile its XSD schema.

        A schema that cannot be fetched (``HTTPError``) or whose archive does
        not contain exactly one file (``IndexError``) is logged as a warning.
        In that case ``self._xsd`` stays ``None`` and ``is_data_valid`` reports
        the data as invalid instead of failing with ``AttributeError``.
        """
        super().__init__()

        # Placeholders; both are assigned by validate_data().
        self._data_to_validate: _Element = NotImplemented
        self.namespace: dict = NotImplemented

        self.DATA_URL = f"{self.DATA_BASE_URL}3984"
        self._TAG_ROW: Final = "ns1:Country"
        self._PATH_ROW: Final = f".//{self._TAG_ROW}"
        self._TAG_BEGINNING_COLUMN: Final = f"{self._TAG_ROW}_"
        self.XSD_URL: Final = f"{self.DATA_BASE_URL}320"

        self._data_to_process = self.get_data(self.DATA_URL)

        # Start with "no schema" so that a failure below cannot leave the
        # attribute undefined.
        self._xsd: XMLSchema | None = None
        try:
            self.xsd_to_process: Final = self.get_data(self.XSD_URL)
            self._xsd = self.process_xsd()
        except (HTTPError, IndexError) as exception:
            self.logger.warning(exception)

        self.logger.info(f"{self.__class__.__name__} initialized!")

    def process_xsd(self) -> XMLSchema:
        """Unzip ``self.xsd_to_process`` and compile it into an ``XMLSchema``.

        Raises:
            IndexError: propagated from ``unzip`` when the archive does not
                contain exactly one file.
        """
        xsd_unzipped = self.unzip(self.xsd_to_process)
        return XMLSchema(etree.parse(xsd_unzipped))

    def unzip(self, xsd: bytes) -> BytesIO:
        """Return the content of the single file stored in the .zip archive *xsd*.

        Args:
            xsd: raw bytes of a .zip archive.

        Returns:
            An in-memory buffer holding the uncompressed file content.

        Raises:
            IndexError: if the archive is empty or holds more than one file.
        """
        with ZipFile(BytesIO(xsd), "r") as zipped_file:
            entries = zipped_file.infolist()
            if not entries:
                # Same exception type the bare ``entries[0]`` lookup raised
                # before, but with a message that names the actual problem.
                raise IndexError(
                    f"The .zip file downloaded from {self.XSD_URL} has no files in it!"
                )
            if len(entries) > 1:
                raise IndexError(
                    f"The .zip file downloaded from {self.XSD_URL} has more than one file in it!"
                )
            return BytesIO(zipped_file.read(entries[0]))

    def process_data(self) -> None:
        """Validate the downloaded XML, load it into ``self.data`` and clean it.

        Raises:
            ValueError: re-raised, after being logged as critical, when the
                data cannot be parsed or does not validate.
        """
        try:
            self.validate_data()
        except ValueError as exception:
            self.logger.critical(exception)
            raise

        self.data = self.read_data_from_xml()

        self.rename_columns_manually()
        self.drop_unnecessary_columns()
        self.swap_names_separated_with_comma()

    def validate_data(self) -> None:
        """Parse the downloaded XML and validate it against the XSD schema.

        Whenever parsing fails with ``XMLSyntaxError`` the first line of the
        data is dropped and parsing is retried. On success the root element is
        stored in ``self._data_to_validate`` and its namespace map in
        ``self.namespace``.

        Raises:
            ValueError: if no parsable document remains after dropping leading
                lines, or if the parsed document does not validate.
        """
        # Iterative retry: the former recursive retry repeated the validation
        # and its log message in every outer frame, and ended in an unlogged
        # IndexError/RecursionError when the data never became parsable.
        while True:
            try:
                parsed_data = etree.parse(BytesIO(self._data_to_process))
            except XMLSyntaxError as exception:
                if b"\n" not in self._data_to_process:
                    # Nothing left to strip: report through the same exception
                    # type process_data() logs and re-raises.
                    raise ValueError(
                        f"The .xml file downloaded from {self.DATA_URL} could not be parsed!"
                    ) from exception
                self._data_to_process = self.remove_first_line(self._data_to_process)
            else:
                break

        self.logger.debug(
            f"Data downloaded from {self.DATA_URL} successfully parsed!"
        )
        self._data_to_validate = parsed_data.getroot()
        self.namespace = self._data_to_validate.nsmap

        if not self.is_data_valid():
            raise ValueError(
                f"The .xml file downloaded from {self.DATA_URL} is invalid "
                f"according to the .xsd downloaded and unzipped from {self.XSD_URL}!"
            )
        self.logger.debug(
            f"Data downloaded from {self.DATA_URL} successfully validated!"
        )

    def remove_first_line(self, data: bytes) -> bytes:
        """Return *data* without its first line (everything up to the first ``\\n``).

        Args:
            data: raw bytes containing at least one newline.

        Returns:
            The bytes following the first newline.

        Raises:
            IndexError: if *data* contains no newline.
        """
        _, newline, remainder = data.partition(b"\n")
        if not newline:
            raise IndexError(
                f"Data downloaded from {self.DATA_URL} has no first line to remove!"
            )
        # Logged only after the removal succeeded (previously logged from a
        # ``finally`` clause, i.e. also on failure).
        self.logger.debug(
            f"First line removed from data downloaded from {self.DATA_URL}!"
        )
        return remainder

    def is_data_valid(self) -> bool:
        """Return whether ``self._data_to_validate`` validates against the schema.

        Returns ``False`` when no schema is available.
        """
        if self._xsd is None:
            return False
        return bool(self._xsd.validate(self._data_to_validate))

    def read_data_from_xml(self) -> DataFrame:
        """Read every element matching ``self._PATH_ROW`` into a ``DataFrame``."""
        # future: report wrong documentation URL of pd.read_xml() to JetBrains or pandas developers
        return pd.read_xml(
            path_or_buffer=BytesIO(self._data_to_process),
            xpath=self._PATH_ROW,
            namespaces=self.namespace,
        )

    def rename_columns_manually(self) -> None:
        """Rename the XML-derived columns of ``self.data`` to the table's column names."""
        self.data.rename(
            columns={
                "Country_ISO_Code": "code_iso",
                "Country_UIC_Code": "code_uic",
                "Country_Name_EN": "name_en",
                "Country_Name_FR": "name_fr",
                "Country_Name_DE": "name_de",
            },
            inplace=True,
        )

    def drop_unnecessary_columns(self) -> None:
        """Drop the rows of ``self.data`` whose ``code_uic`` is missing.

        Despite its name this method removes rows, not columns.
        """
        self.data.dropna(
            subset=["code_uic"],
            inplace=True,
        )

    def swap_names_separated_with_comma(self) -> None:
        """Turn names such as ``"Korea, Republic of"`` into ``"Republic of Korea"``.

        Trailing whitespace is stripped from ``name_en`` first. Values that are
        not strings (e.g. missing values in the nullable ``name_fr`` and
        ``name_de`` columns) are left untouched instead of raising
        ``AttributeError``.
        """
        columns_to_swap = [
            "name_en",
            "name_fr",
            "name_de",
        ]
        self.data["name_en"] = self.data["name_en"].apply(
            lambda x: x.rstrip() if isinstance(x, str) else x
        )
        for column_name in columns_to_swap:
            self.data[column_name] = self.data[column_name].apply(
                lambda x: _swap_name(x) if isinstance(x, str) else x
            )

    def add_data(self) -> None:
        """Insert every row of ``self.data`` into table ``countries``.

        All rows are inserted inside one transaction using ``insert ignore``.
        Missing values are sent as ``None`` so that they are stored as NULL.
        """
        with self.database.engine.begin() as connection:
            # Built once; the statement is identical for every row.
            query = """
                insert ignore into countries (
                    code_iso,
                    code_uic,
                    name_en,
                    name_fr,
                    name_de
                )
                values (
                    :code_iso,
                    :code_uic,
                    :name_en,
                    :name_fr,
                    :name_de
                )
                """
            statement = text(query)
            for _, row in self.data.iterrows():
                parameters = {
                    column: None if pd.isna(value) else value
                    for column, value in row.to_dict().items()
                }
                connection.execute(
                    statement,
                    parameters,
                )

        self.logger.info(
            f"Successfully added new data downloaded from {self.DATA_URL} to table `countries`!"
        )
194