"""
Meta Class for Classifiers

....................................................................................................
MIT License
Copyright (c) 2021-2023 AUT Iran, Mohammad H Forouhesh
Copyright (c) 2021-2022 MetoData.ai, Mohammad H Forouhesh
....................................................................................................
This module abstracts classifiers.
"""

import os
import pickle
from typing import List, Union

import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

from ..api import get_resources
from ..preprocess.preprocessing import remove_redundant_characters, remove_emoji
from ..word2vec.w2v_emb import W2VEmb


class MetaClf:
    """Wrap a classifier with W2VEmb text encoding, min-max scaling and save/load helpers."""

    def __init__(self, classifier_instance, text_array: Union[List[str], pd.Series] = None,
                 embedding_doc: list = None, labels: list = None, load_path: str = None):
        self.clf = classifier_instance
        self.emb = W2VEmb()
        self.scaler = None
        # Directory two levels above this module's directory, with a trailing slash.
        self.dir_path = os.path.dirname(
            os.path.dirname(
                os.path.dirname(
                    os.path.realpath(__file__)))) + "/"
        if load_path is not None:
            get_resources(self.dir_path, resource_name=load_path)
            self.load_model(load_path)
        else:
            # Validate before converting: pd.Series(None) is an empty Series, never None.
            assert text_array is not None and labels is not None
            if not isinstance(text_array, pd.Series):
                text_array = pd.Series(text_array)
            # Work on a copy so the caller's Series is not modified.
            text_array = text_array.fillna('')
            self.emb = W2VEmb(embedding_doc)

            encoded = list(map(self.emb.encode, tqdm(text_array)))
            self.labels = list(labels)
            self.scaler = self.prep_scaler(encoded)
            self.encoded_input = self.scaler.transform(encoded)

    def prep_scaler(self, encoded: List[np.ndarray]) -> MinMaxScaler:
        """
        Fit a Min-Max Scaler to use in the pipeline.
        :param encoded:     A list of encoded text vectors, one per input text.
        :return:            A fitted MinMaxScaler.
        """
        scaler = MinMaxScaler()
        scaler.fit(encoded)
        return scaler

    def fit(self):
        """
        Train the classifier on a stratified 80/20 split and print evaluation reports.
        :return:            The fitted classifier.
        """
        X_train, X_test, y_train, y_test = train_test_split(
            self.encoded_input, self.labels, test_size=0.2, random_state=42, stratify=self.labels)
        self.clf.fit(X_train, y_train)
        print('score: ', self.clf.score(X_test, y_test))
        print('============================train============================')
        print(classification_report(y_train, self.clf.predict(X_train)))
        print('=============================test============================')
        print(classification_report(y_test, self.clf.predict(X_test)))
        return self.clf

    def load_model(self, load_path: str) -> None:
        """
        A tool to load model from disk.
        :param load_path:   Model path.
        :return:            None
        """

        loading_prep = lambda string: f'model_dir/{load_path}/{string}'
        self.clf.load_model(loading_prep('model.json'))
        self.emb.load(loading_prep('emb.pkl'))
        with open(loading_prep('scaler.pkl'), 'rb') as f:
            self.scaler = pickle.load(f)

    def save_model(self, save_path: str):
        """
        A tool to save model to disk
        :param save_path:   Saving path.
        :return:            None.
        """
        os.makedirs(f'model_dir/{save_path}', exist_ok=True)
        saving_prep = lambda string: f'model_dir/{save_path}/{string}'
        self.clf.save_model(saving_prep('model.json'))
        self.emb.save(saving_prep('emb.pkl'))
        with open(saving_prep('scaler.pkl'), 'wb') as f:
            pickle.dump(self.scaler, f, pickle.HIGHEST_PROTOCOL)

    def __getitem__(self, item: str) -> int:
        """
        Indexing shorthand: model[text] is equivalent to model.predict(text).
        :param item:    Input text
        :return:        Predicted class (0, 1).
        """
        return self.predict(item)

    def predict(self, input_text: str) -> int:
        """
        Prediction method.
        :param input_text:  input text, string
        :return:            predicted class. (0, 1)
        """
        prep_text = remove_redundant_characters(remove_emoji(input_text))
        vector = self.scaler.transform(self.emb.encode(prep_text).reshape(1, -1))
        return self.clf.predict(vector)[0]
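
The scaling step used by prep_scaler() and predict() can be illustrated with a small, dependency-free sketch. It is not the scikit-learn implementation: the helper names fit_min_max and transform_min_max are hypothetical, the target range is assumed to be the default [0, 1], and constant columns are simply mapped to 0.0 as a simplification. It shows that the minima and maxima are learned once from the training vectors and then reused for any later vector, so an unseen value can land outside [0, 1].

```python
from typing import List, Sequence, Tuple


def fit_min_max(rows: Sequence[Sequence[float]]) -> Tuple[List[float], List[float]]:
    """Return the per-column minima and maxima of a non-empty 2-D list."""
    columns = list(zip(*rows))
    return [min(col) for col in columns], [max(col) for col in columns]


def transform_min_max(rows: Sequence[Sequence[float]],
                      mins: Sequence[float],
                      maxs: Sequence[float]) -> List[List[float]]:
    """Rescale each column with the fitted minima and maxima.

    Assumption of this sketch: a constant column (max == min) is mapped to 0.0.
    """
    scaled = []
    for row in rows:
        scaled.append([
            (x - lo) / (hi - lo) if hi != lo else 0.0
            for x, lo, hi in zip(row, mins, maxs)
        ])
    return scaled


# Toy "encoded" training vectors; the second feature is constant.
train = [[1.0, 10.0], [3.0, 10.0], [5.0, 10.0]]
mins, maxs = fit_min_max(train)
scaled_train = transform_min_max(train, mins, maxs)

# A single new vector is wrapped in a list, mirroring reshape(1, -1) in predict().
scaled_new = transform_min_max([[7.0, 10.0]], mins, maxs)
```

Here the first feature of the new vector scales to 1.5, because 7.0 lies above the largest training value of 5.0. The same effect applies to texts at prediction time whose encoded values fall outside the range seen during training.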