1
|
|
|
import pandas as pd |
2
|
|
|
from tqdm import tqdm |
3
|
|
|
from tracking_policy_agendas.classifiers.pa_clf import PAClf |
4
|
|
|
from tracking_policy_agendas.classifiers.xgb_clf import XgbClf |
5
|
|
|
from tracking_policy_agendas.classifiers.naive_bayes_clf import GNBClf |
6
|
|
|
from tracking_policy_agendas.classifiers.lasso_clf import LassoClf |
7
|
|
|
from tracking_policy_agendas.preprocess.preprocessing import remove_emoji, remove_redundant_characters |
8
|
|
|
|
9
|
|
|
# Enable tqdm's pandas integration; the script entry point below relies on
# the `progress_apply` method this call makes available.
tqdm.pandas()
10
|
|
|
|
11
|
|
|
|
12
|
|
|
def inference_pipeline(model_path: str, input_text: str):
    """Build an ``XgbClf`` from ``model_path`` and return its prediction.

    Args:
        model_path: Value forwarded to ``XgbClf`` as ``load_path``.
        input_text: Input forwarded to the classifier's ``predict`` method.

    Returns:
        Whatever ``XgbClf.predict`` returns for ``input_text``.
    """
    # No training text or labels are supplied here; only load_path is given.
    classifier = XgbClf(load_path=model_path, text_array=None, labels=None)
    prediction = classifier.predict(input_text)
    return prediction
15
|
|
|
|
16
|
|
|
|
17
|
|
|
def main(embedding_frame: pd.DataFrame, dataframe: pd.DataFrame, save_path: str):
    """Fit every classifier on ``dataframe`` and save each one in turn.

    Args:
        embedding_frame: Frame whose ``prep_text`` column is passed to each
            classifier as ``embedding_doc``.
        dataframe: Frame whose ``prep_text`` and ``label`` columns are passed
            to each classifier as ``text_array`` and ``labels``.
        save_path: Suffix of the name given to ``save_model``; each classifier
            prepends its own prefix (``xgb_``, ``pa_``, ``lasso_``, ``gnb_``).
    """
    # (file-name prefix, classifier class), processed in this order.
    trainers = (
        ('xgb_', XgbClf),
        ('pa_', PAClf),
        ('lasso_', LassoClf),
        ('gnb_', GNBClf),
    )
    for prefix, clf_cls in trainers:
        # Columns are looked up per classifier, exactly as in the unrolled form.
        model = clf_cls(
            text_array=dataframe.prep_text,
            labels=dataframe.label,
            embedding_doc=embedding_frame.prep_text,
        )
        model.fit()
        model.save_model(prefix + save_path)
30
|
|
|
|
31
|
|
|
|
32
|
|
|
if __name__ == '__main__':
    # Script entry point: load the sample sheet, clean the text column, drop
    # incomplete rows, then train and save all classifiers.
    df = pd.read_excel('jcpoa_sampling.xlsx')[['text', 'prep_text', 'label']]

    # Clean prep_text exactly once. The earlier `emb_df = df` only aliased the
    # frame, so the same column was cleaned a second time through emb_df.
    # NOTE(review): assumes the cleaning helpers are idempotent, so a single
    # pass yields the same text the double pass produced - confirm.
    df['prep_text'] = df.prep_text.progress_apply(
        lambda item: remove_redundant_characters(remove_emoji(item))
    )

    # Turn empty strings into NaN so dropna() removes rows with any empty or
    # missing value in the selected columns.
    df = df.replace('', float('NaN')).dropna()

    # Real copy, so the embedding frame and the training frame are independent
    # objects rather than two names for the same DataFrame.
    emb_df = df.copy()

    main(emb_df, df, 'jcpoa')