Test Failed
Pull Request: main (#10), created by Mohammad at 07:44

MetaClf.vec_predict()   A

Complexity

Conditions 1

Size

Total Lines 9
Code Lines 4

Duplication

Lines 9
Ratio 100 %

Code Coverage

Tests 3
CRAP Score 1

Importance

Changes 0
Metric   Value
eloc     4
dl       9
loc      9
ccs      3
cts      3
cp       1
rs       10
c        0
b        0
f        0
cc       1
nop      2
crap     1
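The CRAP score in the table combines cyclomatic complexity with test coverage. Under the standard definition by Savoia and Evans, CRAP(m) = comp(m)^2 * (1 - cov(m)/100)^3 + comp(m), with cov(m) in percent. Assuming the tool uses that formula and that cp 1 means full coverage, vec_predict() with cc 1 scores 1^2 * 0 + 1 = 1, which matches the reported value. A minimal sketch (the function name `crap_score` is hypothetical):

```python
def crap_score(complexity: int, coverage_pct: float) -> float:
    """CRAP score: complexity^2 * (1 - coverage)^3 + complexity.

    coverage_pct is the method's test coverage in percent (0..100).
    """
    if not 0.0 <= coverage_pct <= 100.0:
        raise ValueError("coverage_pct must be between 0 and 100")
    uncovered = 1.0 - coverage_pct / 100.0
    return complexity ** 2 * uncovered ** 3 + complexity


# vec_predict(): cc = 1, fully covered
print(crap_score(1, 100.0))   # 1.0
# A hypothetical untested method with complexity 10
print(crap_score(10, 0.0))    # 110.0
```

Because the uncovered fraction is cubed, full coverage removes the quadratic penalty entirely and leaves only the complexity term.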
"""
Meta Class for Classifiers

....................................................................................................
MIT License
Copyright (c) 2021-2023 AUT Iran, Mohammad H Forouhesh
Copyright (c) 2021-2022 MetoData.ai, Mohammad H Forouhesh
....................................................................................................
This module provides MetaClf, which wraps a classifier with Word2Vec text
embeddings and Min-Max feature scaling.
"""

import os
import pickle
from typing import List

import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report

from ..api import get_resources
from ..preprocess.preprocessing import remove_redundant_characters, remove_emoji
from ..word2vec.w2v_emb import W2VEmb


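# Review finding: this class appears to be duplicated elsewhere in the project.
class MetaClf:
    """
    Text classification pipeline: Word2Vec embeddings (W2VEmb), Min-Max scaling,
    and a wrapped classifier instance.
    """

    def __init__(self, classifier_instance, embedding_doc: list = None, load_path: str = None):
        """
        :param classifier_instance: Classifier to wrap; must provide fit, predict, score,
                                    save_model and load_model.
        :param embedding_doc:       Documents used to build the embedding when no saved
                                    model is loaded.
        :param load_path:           Name of a saved model under model_dir/ to fetch and load.
        """
        self.clf = classifier_instance
        self.scaler = None
        # Base directory for resources: two levels above the directory containing this file.
        self.dir_path = os.path.dirname(
            os.path.dirname(
                os.path.dirname(
                    os.path.realpath(__file__)))) + "/"
        if load_path is not None:
            # Fetch the saved artifacts, then restore classifier, embedding and scaler.
            self.emb = W2VEmb()
            get_resources(self.dir_path, resource_name=load_path)
            self.load_model(load_path)
        else:
            # Build a fresh embedding from the supplied documents.
            self.emb = W2VEmb(embedding_doc)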

    @staticmethod
    def prep_scaler(encoded: List[np.ndarray]) -> MinMaxScaler:
        """
        Fit a Min-Max scaler on the encoded training vectors for use in the pipeline.
        :param encoded:     Embedding vectors, one per training text.
        :return:            A fitted MinMaxScaler.
        """
        scaler = MinMaxScaler()
        scaler.fit(encoded)
        return scaler
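prep_scaler delegates to scikit-learn's MinMaxScaler, which learns each feature's minimum and maximum on the training vectors and maps x to (x - min) / (max - min). The pure-Python sketch below (the helper names are hypothetical) mirrors that default behaviour, including two details worth knowing: a constant feature (zero range) is divided by 1 instead of 0, as scikit-learn does, and transform does not clip unseen values to [0, 1] unless the scaler is created with clip=True.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def fit_min_max(rows: Sequence[Sequence[float]]) -> Tuple[Vector, Vector]:
    """Learn per-feature minima and maxima from the training rows."""
    columns = list(zip(*rows))
    return [min(col) for col in columns], [max(col) for col in columns]


def transform_min_max(rows: Sequence[Sequence[float]],
                      mins: Vector, maxs: Vector) -> List[Vector]:
    """Scale each feature with the minima/maxima learned by fit_min_max."""
    scaled_rows = []
    for row in rows:
        scaled = []
        for value, low, high in zip(row, mins, maxs):
            span = high - low
            # Zero range: divide by 1, so constant training features map to 0.
            scaled.append((value - low) / (span if span != 0 else 1.0))
        scaled_rows.append(scaled)
    return scaled_rows


train = [[0.0, 5.0], [10.0, 5.0], [5.0, 5.0]]          # second feature is constant
mins, maxs = fit_min_max(train)
print(transform_min_max(train, mins, maxs))            # [[0.0, 0.0], [1.0, 0.0], [0.5, 0.0]]
print(transform_min_max([[20.0, 7.0]], mins, maxs))    # [[2.0, 2.0]] (not clipped)
```

The unclipped output for unseen data is why test-time embeddings scaled by a scaler fitted on training data can fall outside [0, 1].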

    def fit(self, x_train, y_train):
        """
        Encode and scale the training texts, fit the classifier, and print a
        classification report on the training data.
        :param x_train:     Training texts.
        :param y_train:     Training labels.
        :return:            The fitted classifier.
        """
        encoded = list(map(self.emb.encode, tqdm(x_train)))
        self.scaler = self.prep_scaler(encoded)
        scaled = self.scaler.transform(encoded)  # scale once, reuse for fitting and reporting
        self.clf.fit(scaled, list(y_train))
        print('============================train============================')
        print(classification_report(y_train, self.clf.predict(scaled)))
        return self.clf

    def predict(self, x_test, y_test):
        """
        Evaluate the fitted pipeline on held-out data: print the classifier's
        score and a classification report. Returns nothing.
        :param x_test:      Test texts.
        :param y_test:      Test labels.
        :return:            None
        """
        encoded = list(map(self.emb.encode, tqdm(x_test)))
        scaled = self.scaler.transform(encoded)  # reuse the scaler fitted on training data
        print('score: ', self.clf.score(scaled, list(y_test)))
        print('=============================test============================')
        print(classification_report(y_test, self.clf.predict(scaled)))
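Both fit and predict print scikit-learn's classification_report, and predict also prints clf.score(...). Assuming classifier_instance follows the scikit-learn classifier interface, that score is mean accuracy. The stdlib sketch below (helper names are hypothetical) reproduces the per-class precision, recall, F1 and support figures the report shows, setting undefined ratios to 0.0, which is also what scikit-learn reports by default (with a warning).

```python
from typing import Dict, Hashable, List, Sequence, Tuple


def accuracy(y_true: Sequence[Hashable], y_pred: Sequence[Hashable]) -> float:
    """Fraction of predictions that equal the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def per_class_scores(y_true: Sequence[Hashable],
                     y_pred: Sequence[Hashable]) -> Dict[Hashable, Dict[str, float]]:
    """Precision, recall, F1 and support for every label in either sequence."""
    pairs: List[Tuple[Hashable, Hashable]] = list(zip(y_true, y_pred))
    report = {}
    for label in sorted(set(y_true) | set(y_pred)):
        tp = sum(1 for t, p in pairs if t == label and p == label)
        fp = sum(1 for t, p in pairs if t != label and p == label)
        fn = sum(1 for t, p in pairs if t == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        # support = number of samples whose true label is `label`
        report[label] = {"precision": precision, "recall": recall,
                         "f1-score": f1, "support": tp + fn}
    return report


y_true = [0, 0, 1, 1]
y_pred = [0, 1, 1, 1]
print(accuracy(y_true, y_pred))   # 0.75
for label, scores in per_class_scores(y_true, y_pred).items():
    print(label, scores)
```

For this example, class 0 has precision 1.0 and recall 0.5, while class 1 has precision 2/3 and recall 1.0, so the two classes trade off in opposite directions even though overall accuracy is 0.75.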

    def load_model(self, load_path: str) -> None:
        """
        Load the classifier, embedding and scaler from model_dir/<load_path>/.
        :param load_path:   Model name (subdirectory of model_dir).
        :return:            None
        """
        model_dir = f'model_dir/{load_path}'
        self.clf.load_model(f'{model_dir}/model.json')
        self.emb.load(f'{model_dir}/emb.pkl')
        with open(f'{model_dir}/scaler.pkl', 'rb') as handle:
            self.scaler = pickle.load(handle)

    def save_model(self, save_path: str) -> None:
        """
        Save the classifier, embedding and scaler to model_dir/<save_path>/.
        :param save_path:   Model name (subdirectory of model_dir, created if missing).
        :return:            None
        """
        model_dir = f'model_dir/{save_path}'
        os.makedirs(model_dir, exist_ok=True)
        self.clf.save_model(f'{model_dir}/model.json')
        self.emb.save(f'{model_dir}/emb.pkl')
        with open(f'{model_dir}/scaler.pkl', 'wb') as handle:
            pickle.dump(self.scaler, handle, pickle.HIGHEST_PROTOCOL)
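save_model and load_model persist the scaler with pickle under model_dir/<name>/scaler.pkl. The stdlib sketch below shows the same round trip; ScalerState is a hypothetical stand-in for the fitted MinMaxScaler, and a temporary directory replaces the current working directory. Note that pickle.load can execute arbitrary code, so scaler.pkl should only be loaded from a trusted source, which matters most for downloaded artifacts.

```python
import os
import pickle
import tempfile
from dataclasses import dataclass, field
from typing import List


@dataclass
class ScalerState:
    """Hypothetical stand-in for a fitted scaler's learned parameters."""
    data_min: List[float] = field(default_factory=list)
    data_max: List[float] = field(default_factory=list)


def save_state(state: ScalerState, root: str, name: str) -> str:
    """Write state to <root>/model_dir/<name>/scaler.pkl and return the path."""
    model_dir = os.path.join(root, "model_dir", name)
    os.makedirs(model_dir, exist_ok=True)          # same as save_model: create if missing
    path = os.path.join(model_dir, "scaler.pkl")
    with open(path, "wb") as handle:
        pickle.dump(state, handle, pickle.HIGHEST_PROTOCOL)
    return path


def load_state(root: str, name: str) -> ScalerState:
    """Read the state written by save_state (trusted files only)."""
    with open(os.path.join(root, "model_dir", name, "scaler.pkl"), "rb") as handle:
        return pickle.load(handle)


with tempfile.TemporaryDirectory() as tmp:
    original = ScalerState([0.0, 5.0], [10.0, 5.0])
    saved_path = save_state(original, tmp, "demo")
    restored = load_state(tmp, "demo")

print(restored == original)   # True
```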

    def __getitem__(self, item: str) -> int:
        """
        Indexing shorthand for prediction: model[text] returns model.vec_predict(text).
        :param item:    Input text.
        :return:        Predicted class (0 or 1).
        """
        return self.vec_predict(item)

    def vec_predict(self, input_text: str) -> int:
        """
        Predict the class of a single input text.
        :param input_text:  Input text (str).
        :return:            Predicted class (0 or 1).
        """
        prep_text = remove_redundant_characters(remove_emoji(input_text))
        vector = self.scaler.transform(self.emb.encode(prep_text).reshape(1, -1))
        return self.clf.predict(vector)[0]
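vec_predict runs one text through components that expect batches: .reshape(1, -1) turns a single embedding into a one-row matrix for the scaler, and [0] unwraps the single prediction. The sketch below isolates that batch-of-one pattern in plain Python; the toy preprocess, encode, transform and predict callables are hypothetical stand-ins for remove_emoji/remove_redundant_characters, W2VEmb.encode, MinMaxScaler.transform and the wrapped classifier.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def predict_one(text: str,
                preprocess: Callable[[str], str],
                encode: Callable[[str], Vector],
                transform: Callable[[Sequence[Vector]], List[Vector]],
                predict: Callable[[Sequence[Vector]], List[int]]) -> int:
    """Classify a single text with batch-oriented components."""
    vector = encode(preprocess(text))
    batch = [vector]            # like .reshape(1, -1): one row, one sample
    scaled = transform(batch)   # batch in, batch out
    return predict(scaled)[0]   # unwrap the only prediction


# Hypothetical toy components, for illustration only.
def toy_preprocess(text: str) -> str:
    return " ".join(text.split())              # collapse redundant whitespace


def toy_encode(text: str) -> Vector:
    return [float(len(text)), float(text.count("!"))]


def toy_transform(batch: Sequence[Vector]) -> List[Vector]:
    return [[row[0] / 100.0, row[1] / 5.0] for row in batch]


def toy_predict(batch: Sequence[Vector]) -> List[int]:
    return [1 if row[1] > 0 else 0 for row in batch]


print(predict_one("great  product!", toy_preprocess, toy_encode, toy_transform, toy_predict))  # 1
print(predict_one("it is   fine", toy_preprocess, toy_encode, toy_transform, toy_predict))     # 0
```

In MetaClf the same single-text path is also reachable as model[text] through __getitem__.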