Passed
Pull Request — master (#110)
by
unknown
01:31
created

SklearnCvExperiment.__init__()   A

Complexity

Conditions 1

Size

Total Lines 6
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 6
nop 6
dl 0
loc 6
rs 10
c 0
b 0
f 0
1
"""Experiment adapter for sklearn cross-validation experiments."""
2
3
from sklearn import clone
4
from sklearn.model_selection import cross_validate
5
from sklearn.utils.validation import _num_samples
6
7
from hyperactive.base import BaseExperiment
8
9
class SklearnCvExperiment(BaseExperiment):
    """Experiment adapter for sklearn cross-validation experiments.

    This class is used to perform cross-validation experiments using a given
    sklearn estimator. It allows for hyperparameter tuning and evaluation of
    the model's performance using cross-validation.

    The score returned is the mean of the cross-validation test scores
    obtained by applying cross-validation to ``estimator`` with the
    parameters passed to ``score`` as ``params``.

    The cross-validation performed is specified by the ``cv`` parameter,
    and the scoring metric is specified by the ``scoring`` parameter.
    The ``X`` and ``y`` parameters are the input data and target values,
    which are used in fit/predict cross-validation.

    Parameters
    ----------
    estimator : sklearn estimator
        The estimator to be used for the experiment. It is cloned before
        every evaluation, so the instance passed here is never fitted by
        this class.
    scoring : str, callable or None
        Forwarded unchanged to ``sklearn.model_selection.cross_validate``.
        Either the name of an sklearn scorer (e.g. ``"accuracy"``) or a
        scorer callable with signature ``scorer(estimator, X, y)`` (for
        example one built with ``sklearn.metrics.make_scorer``). A plain
        metric function such as ``accuracy_score`` is not a scorer.
        If None, the estimator's own ``score`` method is used.
    cv : int, cross-validation generator or iterable
        The number of folds or cross-validation strategy to be used;
        forwarded unchanged to ``cross_validate``.
    X : array-like, shape (n_samples, n_features)
        The input data for the model.
    y : array-like, shape (n_samples,) or (n_samples, n_outputs)
        The target values for the model.
    """

    def __init__(self, estimator, scoring, cv, X, y):
        self.estimator = estimator
        self.X = X
        self.y = y
        self.scoring = scoring
        self.cv = cv

        # Run the base-class initialisation after the parameters are set,
        # so any setup performed by BaseExperiment (and its ancestors) is
        # not skipped.
        super().__init__()

    def _paramnames(self):
        """Return the parameter names of the search.

        Returns
        -------
        list of str
            The parameter names of the search parameters, i.e. the keys of
            ``self.estimator.get_params()``.
        """
        return list(self.estimator.get_params().keys())

    def _score(self, **params):
        """Score the parameters.

        Parameters
        ----------
        params : dict with string keys
            Parameters to score; applied to a clone of ``estimator`` via
            ``set_params``.

        Returns
        -------
        float
            Mean of the per-fold test scores (``cv_results["test_score"]``)
            computed with the configured ``scoring`` and ``cv``.
        dict
            Additional metadata about the search, with keys:

            - ``"score_time"``: per-fold scoring times from ``cross_validate``.
            - ``"fit_time"``: per-fold fitting times from ``cross_validate``.
            - ``"n_test_samples"``: number of samples in ``X``.
        """
        # Clone so that evaluations never mutate or fit the user-supplied
        # estimator, and parameters from one call don't leak into the next.
        estimator = clone(self.estimator)
        estimator.set_params(**params)

        cv_results = cross_validate(
            estimator,
            self.X,
            self.y,
            # Previously omitted: without this, the stored ``scoring`` was
            # ignored and the estimator's default ``score`` was always used.
            scoring=self.scoring,
            cv=self.cv,
        )

        add_info_d = {
            "score_time": cv_results["score_time"],
            "fit_time": cv_results["fit_time"],
            # NOTE(review): this is the total number of samples in X, not the
            # size of each fold's test split -- confirm this is intended.
            "n_test_samples": _num_samples(self.X),
        }

        return cv_results["test_score"].mean(), add_info_d
88