| 1 |  |  | """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 2 |  |  | Meta Class for Classifiers | 
            
                                                                                                            
                            
            
                                    
            
            
                | 3 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 4 |  |  | .................................................................................................... | 
            
                                                                                                            
                            
            
                                    
            
            
                | 5 |  |  | MIT License | 
            
                                                                                                            
                            
            
                                    
            
            
                | 6 |  |  | Copyright (c) 2021-2023 AUT Iran, Mohammad H Forouhesh | 
            
                                                                                                            
                            
            
                                    
            
            
                | 7 |  |  | Copyright (c) 2021-2022 MetoData.ai, Mohammad H Forouhesh | 
            
                                                                                                            
                            
            
                                    
            
            
                | 8 |  |  | .................................................................................................... | 
            
                                                                                                            
                            
            
                                    
            
            
                | 9 |  |  | This module abstracts classifiers. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 10 |  |  | """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 11 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 12 | 1 |  | import os | 
            
                                                                                                            
                            
            
                                    
            
            
                | 13 | 1 |  | import pickle | 
            
                                                                                                            
                            
            
                                    
            
            
                | 14 | 1 |  | from typing import List, Union | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 15 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 16 | 1 |  | import numpy as np | 
            
                                                                                                            
                            
            
                                    
            
            
                | 17 | 1 |  | import pandas as pd | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 18 | 1 |  | from tqdm import tqdm | 
            
                                                                                                            
                            
            
                                    
            
            
                | 19 | 1 |  | from sklearn.preprocessing import MinMaxScaler | 
            
                                                                                                            
                            
            
                                    
            
            
                | 20 | 1 |  | from sklearn.metrics import classification_report | 
            
                                                                                                            
                            
            
                                    
            
            
                | 21 | 1 |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 22 |  |  | from ..api import get_resources | 
            
                                                                                                            
                            
            
                                    
            
            
                | 23 | 1 |  | from ..preprocess.preprocessing import remove_redundant_characters, remove_emoji | 
            
                                                                                                            
                            
            
                                    
            
            
                | 24 | 1 |  | from ..word2vec.w2v_emb import W2VEmb | 
            
                                                                                                            
                            
            
                                    
            
            
                | 25 | 1 |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 26 |  |  |  | 
            
                                                                                                            
                            
            
                                                                    
                                                                                                        
            
            
                | 27 |  | View Code Duplication | class MetaClf: | 
                            
                    |  |  |  | 
                                                                                        
                                                                                            
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 28 | 1 |  |     def __init__(self, classifier_instance, embedding_doc: list = None, load_path: str = None): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 29 | 1 |  |         self.clf = classifier_instance | 
            
                                                                                                            
                            
            
                                    
            
            
                | 30 |  |  |         self.emb = W2VEmb() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 31 | 1 |  |         self.scaler = None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 32 |  |  |         self.dir_path = os.path.dirname( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 33 | 1 |  |             os.path.dirname( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 34 | 1 |  |                 os.path.dirname( | 
            
                                                                                                            
                            
            
                                    
            
            
                | 35 | 1 |  |                     os.path.realpath(__file__)))) + "/" | 
            
                                                                                                            
                            
            
                                    
            
            
                | 36 | 1 |  |         if load_path is not None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 37 |  |  |             get_resources(self.dir_path, resource_name=load_path) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 38 |  |  |             self.load_model(load_path) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 39 |  |  |         else: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 40 | 1 |  |             self.emb = W2VEmb(embedding_doc) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 41 | 1 |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 42 | 1 |  |     def prep_scaler(self, encoded: List[np.ndarray]) -> MinMaxScaler: | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 43 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 44 |  |  |         Fitting a Min-Max Scaler to use in the pipeline | 
            
                                                                                                            
                            
            
                                    
            
            
                | 45 |  |  |         :param encoded:     An array of numbers. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 46 |  |  |         :return:            A MinMaxScaler | 
            
                                                                                                            
                            
            
                                    
            
            
                | 47 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 48 |  |  |         scaler = MinMaxScaler() | 
            
                                                                                                            
                            
            
                                    
            
            
                | 49 |  |  |         scaler.fit(encoded) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 50 |  |  |         return scaler | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 51 |  |  |  | 
            
                                                                        
                            
            
                                    
            
            
                | 52 |  |  |     def fit(self, X_train, y_train): | 
                            
                    |  |  |  | 
                                                                                        
                                                                                            
                                                                                     | 
            
                                                                        
                            
            
                                    
            
            
                | 53 | 1 |  |         encoded = list(map(self.emb.encode, tqdm(X_train))) | 
            
                                                                        
                            
            
                                    
            
            
                | 54 |  |  |         self.scaler = self.prep_scaler(encoded) | 
            
                                                                        
                            
            
                                    
            
            
                | 55 |  |  |         self.clf.fit(self.scaler.transform(encoded), list(y_train)) | 
            
                                                                        
                            
            
                                    
            
            
                | 56 |  |  |         print('============================trian============================') | 
            
                                                                        
                            
            
                                    
            
            
                | 57 |  |  |         print(classification_report(y_train, self.clf.predict(self.scaler.transform(encoded)))) | 
            
                                                                        
                            
            
                                    
            
            
                | 58 |  |  |         return self.clf | 
            
                                                                                                            
                            
            
                                    
            
            
                | 59 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 60 |  |  |     def predict(self, X_test, y_test): | 
                            
                    |  |  |  | 
                                                                                        
                                                                                            
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 61 |  |  |         encoded = list(map(self.emb.encode, tqdm(X_test))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 62 |  |  |         print('score: ', self.clf.score(self.scaler.transform(encoded), list(y_test))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 63 | 1 |  |         print('=============================test============================') | 
            
                                                                                                            
                            
            
                                    
            
            
                | 64 |  |  |         print(classification_report(y_test, self.clf.predict(self.scaler.transform(encoded)))) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 65 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 66 |  |  |     def load_model(self, load_path: str) -> None: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 67 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 68 |  |  |         A tool to load model from disk. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 69 |  |  |         :param load_path:   Model path. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 70 |  |  |         :return:            None | 
            
                                                                                                            
                            
            
                                    
            
            
                | 71 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 72 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 73 |  |  |         loading_prep = lambda string: f'model_dir/{load_path}/{string}' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 74 | 1 |  |         self.clf.load_model(loading_prep('model.json')) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 75 |  |  |         self.emb.load(loading_prep('emb.pkl')) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 76 |  |  |         with open(loading_prep('scaler.pkl'), 'rb') as f: | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 77 |  |  |             self.scaler = pickle.load(f) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 78 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 79 |  |  |     def save_model(self, save_path: str): | 
            
                                                                                                            
                            
            
                                    
            
            
                | 80 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 81 | 1 |  |         A tool to save model to disk | 
            
                                                                                                            
                            
            
                                    
            
            
                | 82 | 1 |  |         :param save_path:   Saving path. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 83 | 1 |  |         :return:            None. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 84 | 1 |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 85 | 1 |  |         os.makedirs(f'model_dir/{save_path}', exist_ok=True) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 86 |  |  |         saving_prep = lambda string: f'model_dir/{save_path}/{string}' | 
            
                                                                                                            
                            
            
                                    
            
            
                | 87 | 1 |  |         self.clf.save_model(saving_prep('model.json')) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 88 |  |  |         self.emb.save(saving_prep('emb.pkl')) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 89 |  |  |         with open(saving_prep('scaler.pkl'), 'wb') as f: | 
                            
                    |  |  |  | 
                                                                                        
                                                                                     | 
            
                                                                                                            
                            
            
                                    
            
            
                | 90 |  |  |             pickle.dump(self.scaler, f, pickle.HIGHEST_PROTOCOL) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 91 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 92 |  |  |     def __getitem__(self, item: str) -> int: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 93 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 94 |  |  |         getitem overwritten | 
            
                                                                                                            
                            
            
                                    
            
            
                | 95 |  |  |         :param item:    Input text | 
            
                                                                                                            
                            
            
                                    
            
            
                | 96 |  |  |         :return:        Predicted class (0, 1). | 
            
                                                                                                            
                            
            
                                    
            
            
                | 97 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 98 |  |  |         return self.vec_predict(item) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 99 |  |  |  | 
            
                                                                                                            
                            
            
                                    
            
            
                | 100 | 1 |  |     def vec_predict(self, input_text: str) -> int: | 
            
                                                                                                            
                            
            
                                    
            
            
                | 101 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 102 |  |  |         Prediction method. | 
            
                                                                                                            
                            
            
                                    
            
            
                | 103 |  |  |         :param input_text:  input text, string | 
            
                                                                                                            
                            
            
                                    
            
            
                | 104 |  |  |         :return:            predicted class. (0, 1) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 105 |  |  |         """ | 
            
                                                                                                            
                            
            
                                    
            
            
                | 106 | 1 |  |         prep_text = remove_redundant_characters(remove_emoji(input_text)) | 
            
                                                                                                            
                            
            
                                    
            
            
                | 107 |  |  |         vector = self.scaler.transform(self.emb.encode(prep_text).reshape(1, -1)) | 
            
                                                                                                            
                                                                
            
                                    
            
            
                | 108 |  |  |         return self.clf.predict(vector)[0] | 
            
                                                        
            
                                    
            
            
                | 109 |  |  |  |