Passed
Pull Request — main (#6)
by Mohammad
04:13
created

MetaClf.__init__()   A

Complexity

Conditions 3

Size

Total Lines 22
Code Lines 19

Duplication

Lines 22
Ratio 100 %

Code Coverage

Tests 9
CRAP Score 3.7536

Importance

Changes 0
Metric Value
eloc 19
dl 22
loc 22
ccs 9
cts 16
cp 0.5625
rs 9.45
c 0
b 0
f 0
cc 3
nop 6
crap 3.7536
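
The CRAP score in the table can be reproduced from the cyclomatic complexity (cc) and the coverage ratio (cp). The following is a minimal sketch, assuming the commonly cited CRAP formula, comp^2 * (1 - cov)^3 + comp, and assuming that cp is simply ccs / cts. The function name crap_score is illustrative and not part of the project or of the inspection tool.

```python
def crap_score(complexity: int, coverage: float) -> float:
    """CRAP = comp^2 * (1 - cov)^3 + comp, with coverage as a fraction in [0, 1]."""
    if not 0.0 <= coverage <= 1.0:
        raise ValueError("coverage must be a fraction between 0 and 1")
    return complexity ** 2 * (1.0 - coverage) ** 3 + complexity


# Values taken from the metric table: cc = 3, ccs = 9, cts = 16.
coverage = 9 / 16  # equals the reported cp of 0.5625
score = crap_score(3, coverage)
print(f"{score:.6f}")  # prints 3.753662
```

Under these assumptions the result agrees with the reported 3.7536 once it is cut to four decimals. The formula also shows why coverage matters here: at full coverage the score would fall to the bare complexity of 3.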
"""
Meta Class for Classifiers

....................................................................................................
MIT License
Copyright (c) 2021-2023 AUT Iran, Mohammad H Forouhesh
Copyright (c) 2021-2022 MetoData.ai, Mohammad H Forouhesh
....................................................................................................
This module abstracts classifiers.
"""

import os
import pickle
from typing import List

import numpy as np
Unused Code: Unused numpy imported as np
import pandas as pd
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split

from ..api import get_resources
from ..preprocess.preprocessing import remove_redundant_characters, remove_emoji
from ..word2vec.w2v_emb import W2VEmb


class MetaClf:
Duplication: This code seems to be duplicated in your project.
Missing class docstring
    def __init__(self, classifier_instance, text_array: list = None, embedding_doc: list = None, labels: list = None, load_path: str = None):
best-practice: Too many arguments (6/5)
Coding Style: This line is too long as per the coding-style (141/100).

This check looks for lines that are too long. You can specify the maximum line length.
        if not isinstance(text_array, pd.Series): text_array = pd.Series(text_array)
Coding Style: More than one statement on a single line

        self.clf = classifier_instance
        self.emb = W2VEmb()
        self.scaler = None
        self.dir_path = os.path.dirname(
            os.path.dirname(
                os.path.dirname(
                    os.path.realpath(__file__)))) + "/"
        if load_path is not None:
            get_resources(self.dir_path, resource_name=load_path)
            self.load_model(load_path)
        else:
            assert text_array is not None and labels is not None
            text_array.fillna('', inplace=True)
            self.emb = W2VEmb(embedding_doc)

            encoded = list(map(self.emb.encode, tqdm(text_array)))
            self.labels = list(labels)
            self.scaler = self.prep_scaler(encoded)
            self.encoded_input = self.scaler.transform(encoded)

    def prep_scaler(self, encoded: List[int]) -> MinMaxScaler:
Coding Style: This method could be written as a function/class method.

If a method does not access any attributes of the class, it could also be implemented as a function or static method. This can help improve readability. For example:

class Foo:
    def some_method(self, x, y):
        return x + y

could be written as

class Foo:
    @classmethod
    def some_method(cls, x, y):
        return x + y
        """
54
        Fitting a Min-Max Scaler to use in the pipeline
55
        :param encoded:     An array of numbers.
56
        :return:            A MinMaxScaler
57
        """
58
        scaler = MinMaxScaler()
59
        scaler.fit(encoded)
60
        return scaler
61
62 1
    def fit(self):
Missing function or method docstring
        X_train, X_test, y_train, y_test = train_test_split(self.encoded_input, self.labels, test_size=0.2,
Coding Style: This line is too long as per the coding-style (107/100).
Coding Style Naming: Variable name "X_train" doesn't conform to snake_case naming style ('([^\\W\\dA-Z][^\\WA-Z]{2,}|_[^\\WA-Z]*|__[^\\WA-Z\\d_][^\\WA-Z]+__)$' pattern)

This check looks for invalid names for a range of different identifiers.

You can set regular expressions to which the identifiers must conform if the defaults do not match your requirements.

If your project includes a Pylint configuration file, the settings contained in that file take precedence.

To find out more about Pylint, please refer to their site.

Coding Style Naming: Variable name "X_test" doesn't conform to snake_case naming style ('([^\\W\\dA-Z][^\\WA-Z]{2,}|_[^\\WA-Z]*|__[^\\WA-Z\\d_][^\\WA-Z]+__)$' pattern)
                                                            random_state=42, stratify=self.labels)
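
The naming findings reported for X_train and X_test (and for the file handle f further down) can be reproduced with the pattern quoted in the message. This is a small sketch using only the standard re module, assuming the quantifier in the first alternative is {2,} as in Pylint's default snake_case style. The helper name conforms is illustrative, and Pylint applies further rules, such as its good-names list, that are not modeled here.

```python
import re

# snake_case pattern from the message, written with single backslashes.
SNAKE_CASE = re.compile(
    r"([^\W\dA-Z][^\WA-Z]{2,}|_[^\WA-Z]*|__[^\WA-Z\d_][^\WA-Z]+__)$"
)


def conforms(name: str) -> bool:
    """Return True if the whole name matches the snake_case pattern."""
    return SNAKE_CASE.match(name) is not None


for name in ("X_train", "X_test", "f", "x_train", "handle"):
    print(name, conforms(name))
```

The first three names are rejected (an uppercase first letter, or fewer than three characters), while x_train and handle pass. Whether to rename X_train or to relax the rule for the conventional scikit-learn naming is a project decision. As the message notes, a Pylint configuration file takes precedence.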
        self.clf.fit(X_train, y_train)
        print('score: ', self.clf.score(X_test, y_test))
        print('============================train============================')
        print(classification_report(y_train, self.clf.predict(X_train)))
        print('=============================test============================')
        print(classification_report(y_test, self.clf.predict(X_test)))
        return self.clf

    def load_model(self, load_path: str) -> None:
        """
        Load a model from disk.
        :param load_path:   Model path.
        :return:            None
        """

        loading_prep = lambda string: f'model_dir/{load_path}/{string}'
        self.clf.load_model(loading_prep('model.json'))
        self.emb.load(loading_prep('emb.pkl'))
        with open(loading_prep('scaler.pkl'), 'rb') as f:
Coding Style Naming: Variable name "f" doesn't conform to snake_case naming style ('([^\\W\\dA-Z][^\\WA-Z]{2,}|_[^\\WA-Z]*|__[^\\WA-Z\\d_][^\\WA-Z]+__)$' pattern)
            self.scaler = pickle.load(f)

    def save_model(self, save_path: str):
        """
        Save the model to disk.
        :param save_path:   Saving path.
        :return:            None.
        """
        os.makedirs(f'model_dir/{save_path}', exist_ok=True)
        saving_prep = lambda string: f'model_dir/{save_path}/{string}'
        self.clf.save_model(saving_prep('model.json'))
        self.emb.save(saving_prep('emb.pkl'))
        with open(saving_prep('scaler.pkl'), 'wb') as f:
Coding Style Naming: Variable name "f" doesn't conform to snake_case naming style ('([^\\W\\dA-Z][^\\WA-Z]{2,}|_[^\\WA-Z]*|__[^\\WA-Z\\d_][^\\WA-Z]+__)$' pattern)
            pickle.dump(self.scaler, f, pickle.HIGHEST_PROTOCOL)
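
save_model and load_model mirror each other: whatever is pickled on save must be readable from the same location on load. The following is a minimal round-trip sketch using only the standard library. The dictionary standing in for the fitted scaler, the helper names save_scaler and load_scaler, and the temporary directory used in place of model_dir/ are all assumptions for illustration, not the project's objects.

```python
import os
import pickle
import tempfile


def save_scaler(scaler, directory: str) -> str:
    """Pickle the scaler to <directory>/scaler.pkl and return the file path."""
    os.makedirs(directory, exist_ok=True)
    path = os.path.join(directory, "scaler.pkl")
    with open(path, "wb") as handle:
        pickle.dump(scaler, handle, pickle.HIGHEST_PROTOCOL)
    return path


def load_scaler(directory: str):
    """Read the scaler back from <directory>/scaler.pkl."""
    with open(os.path.join(directory, "scaler.pkl"), "rb") as handle:
        return pickle.load(handle)


# Hypothetical stand-in for a fitted scaler: per-feature minima and maxima.
original = {"data_min": [0.0, -1.0], "data_max": [2.5, 4.0]}

with tempfile.TemporaryDirectory() as tmp_dir:
    model_dir = os.path.join(tmp_dir, "demo")
    save_scaler(original, model_dir)
    restored = load_scaler(model_dir)

print(restored == original)
```

Building paths with os.path.join keeps the save and load sides in sync, and naming the file object handle rather than f would also satisfy the naming check reported above. Since pickle.load can execute arbitrary code, it should only be used on files from a trusted source.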

    def __getitem__(self, item: str) -> int:
        """
        Delegate indexing to predict, so that indexing with a text returns its prediction.
        :param item:    Input text
        :return:        Predicted class (0, 1).
        """
        return self.predict(item)

    def predict(self, input_text: str) -> int:
        """
        Predict the class of a single text.
        :param input_text:  Input text (string).
        :return:            Predicted class (0 or 1).
        """
        prep_text = remove_redundant_characters(remove_emoji(input_text))
        vector = self.scaler.transform(self.emb.encode(prep_text).reshape(1, -1))
        return self.clf.predict(vector)[0]
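
Because __getitem__ delegates to predict, an instance can be indexed with a text as a shorthand for calling predict. Here is a self-contained sketch of that delegation. KeywordClf and its keyword rule are hypothetical and only stand in for the embedding, scaler and classifier pipeline used by MetaClf.

```python
class KeywordClf:
    """Hypothetical stand-in: predicts 1 if the text contains a keyword, else 0."""

    def __init__(self, keyword: str):
        self.keyword = keyword.lower()

    def predict(self, input_text: str) -> int:
        # Case-insensitive substring test instead of a real model.
        return int(self.keyword in input_text.lower())

    def __getitem__(self, item: str) -> int:
        # Same delegation as MetaClf.__getitem__: indexing calls predict.
        return self.predict(item)


clf = KeywordClf("great")
print(clf["This is GREAT"], clf.predict("This is GREAT"), clf["meh"])
```

Both spellings return the same value. Indexing is compact, but an explicit predict call is usually easier to recognize in calling code.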