main.main()   A
last analyzed

Complexity

Conditions 1

Size

Total Lines 13
Code Lines 13

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 13
dl 0
loc 13
rs 9.75
c 0
b 0
f 0
cc 1
nop 3
1
import pandas as pd
2
from tqdm import tqdm
3
from tracking_policy_agendas.classifiers.pa_clf import PAClf
4
from tracking_policy_agendas.classifiers.xgb_clf import XgbClf
5
from tracking_policy_agendas.classifiers.naive_bayes_clf import GNBClf
6
from tracking_policy_agendas.classifiers.lasso_clf import LassoClf
7
from tracking_policy_agendas.preprocess.preprocessing import remove_emoji, remove_redundant_characters
8
9
# Register tqdm's pandas integration: this adds `progress_apply` (an `apply`
# that displays a progress bar) to pandas objects; the __main__ block relies on it.
tqdm.pandas()
10
11
12
def inference_pipeline(model_path: str, input_text: str):
    """Restore a saved XGBoost classifier and predict on ``input_text``.

    Args:
        model_path: Location handed to ``XgbClf`` as its ``load_path``.
        input_text: Passed unchanged to ``XgbClf.predict``.

    Returns:
        Whatever ``XgbClf.predict`` returns for ``input_text``.
    """
    # text_array / labels are left as None: the classifier is built from
    # load_path rather than from training data.
    # NOTE(review): the model is reloaded on every call; cache the instance if
    # this function is invoked repeatedly.
    restored = XgbClf(load_path=model_path, text_array=None, labels=None)
    prediction = restored.predict(input_text)
    return prediction
15
16
17
def main(embedding_frame: pd.DataFrame, dataframe: pd.DataFrame, save_path: str,
         *, classifiers=None) -> None:
    """Train each classifier on the prepared text and save it to disk.

    Every classifier is constructed with the same inputs, fitted, then saved
    under ``save_path`` with its prefix prepended to the file name
    (``"jcpoa"`` -> ``"xgb_jcpoa"``, ``"models/jcpoa"`` -> ``"models/xgb_jcpoa"``).

    Args:
        embedding_frame: Frame whose ``prep_text`` column is passed to each
            classifier as ``embedding_doc``.
        dataframe: Frame providing ``prep_text`` (training text) and ``label``.
        save_path: Base path for the saved models; each model's prefix is
            added to the final path component.
        classifiers: Optional sequence of ``(prefix, classifier_class)`` pairs.
            Defaults to XGBoost, PA, Lasso and Gaussian NB (in that order), which
            matches the original fixed set of models.
    """
    import os  # only needed here, for splitting/joining save_path

    if classifiers is None:
        classifiers = (
            ('xgb', XgbClf),
            ('pa', PAClf),
            ('lasso', LassoClf),
            ('gnb', GNBClf),
        )

    directory, filename = os.path.split(save_path)
    for prefix, clf_cls in classifiers:
        model = clf_cls(
            text_array=dataframe.prep_text,
            labels=dataframe.label,
            embedding_doc=embedding_frame.prep_text,
        )
        model.fit()
        # Prefix only the file name: plain string concatenation would turn
        # "models/jcpoa" into "xgb_models/jcpoa" (a different directory).
        model.save_model(os.path.join(directory, f'{prefix}_{filename}'))
30
31
32
if __name__ == '__main__':
    # Load the sampled corpus, keeping only the columns the pipeline uses; the
    # explicit copy makes the column assignment below act on an independent frame.
    df = pd.read_excel('jcpoa_sampling.xlsx')[['text', 'prep_text', 'label']].copy()

    # Clean the prepared text once: strip emoji, then redundant characters.
    df['prep_text'] = df.prep_text.progress_apply(
        lambda item: remove_redundant_characters(remove_emoji(item))
    )

    # Texts left empty by cleaning become NaN so dropna removes them (dropna also
    # removes rows with a missing value in any of the kept columns).
    df = df.replace('', float('NaN')).dropna()

    # The embedding corpus is the same cleaned data. Copying it (instead of
    # aliasing the pre-filter frame and cleaning it a second time) avoids
    # re-running the cleaners on already-cleaned text.
    emb_df = df.copy()

    main(emb_df, df, 'jcpoa')