Passed
Pull Request — master (#87)
by Simon
01:22
created

ObjectiveFunctionAdapter.add_validation()   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 3
dl 0
loc 3
rs 10
c 0
b 0
f 0
1
# Author: Simon Blanke
2
# Email: [email protected]
3
# License: MIT License
4
5
6
from sklearn.model_selection import cross_validate
7
from sklearn.utils.validation import _num_samples
8
9
10
class ObjectiveFunctionAdapter:
    """Wrap an estimator, a dataset and a CV setup as an objective function.

    Configure the adapter with ``add_dataset`` and ``add_validation`` and then
    hand ``objective_function`` to an optimizer. Each call evaluates one
    hyperparameter candidate via ``cross_validate``.
    """

    def __init__(self, estimator) -> None:
        """Store the estimator used as the template for every evaluation."""
        self.estimator = estimator

    def add_dataset(self, X, y):
        """Register the features ``X`` and targets ``y`` to cross-validate on."""
        self.X = X
        self.y = y

    def add_validation(self, scoring, cv):
        """Register the scoring and the ``cv`` argument for ``cross_validate``.

        ``scoring`` should describe a single metric (or be ``None``), because
        ``objective_function`` reads the ``"test_score"`` result key.
        """
        self.scoring = scoring
        self.cv = cv

    def objective_function(self, params):
        """Cross-validate the estimator configured with ``params``.

        Parameters
        ----------
        params : mapping
            Hyperparameter names to values; unpacked into ``set_params``.

        Returns
        -------
        tuple
            ``(mean_test_score, add_info_d)`` where ``add_info_d`` holds the
            ``"score_time"`` and ``"fit_time"`` entries from ``cross_validate``
            and ``"n_test_samples"``.
        """
        # Local import keeps this block self-contained; sklearn is already a
        # dependency of this module.
        from sklearn.base import clone

        # Evaluate a fresh copy so the candidate parameters are actually
        # applied and the shared template estimator is never mutated.
        estimator = clone(self.estimator)
        estimator.set_params(**params)

        cv_results = cross_validate(
            estimator,
            self.X,
            self.y,
            scoring=self.scoring,
            cv=self.cv,
        )

        add_info_d = {
            "score_time": cv_results["score_time"],
            "fit_time": cv_results["fit_time"],
            # Computed on the full X, i.e. not a per-fold test-set size.
            "n_test_samples": _num_samples(self.X),
        }

        return cv_results["test_score"].mean(), add_info_d
37