Passed
Pull Request — master (#110)
by
unknown
12:02 queued 10:25
created

SklearnCvExperiment.get_test_params()   A

Complexity

Conditions 1

Size

Total Lines 56
Code Lines 21

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 21
dl 0
loc 56
rs 9.376
c 0
b 0
f 0
cc 1
nop 2

How to fix: Long Method

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method.

1
"""Experiment adapter for sklearn cross-validation experiments."""
2
3
from sklearn import clone
4
from sklearn.model_selection import cross_validate
5
from sklearn.utils.validation import _num_samples
6
7
from hyperactive.base import BaseExperiment
8
9
class SklearnCvExperiment(BaseExperiment):
    """Experiment adapter for sklearn cross-validation experiments.

    This class is used to perform cross-validation experiments using a given
    sklearn estimator. It allows for hyperparameter tuning and evaluation of
    the model's performance using cross-validation.

    The score returned is the mean of the cross-validation scores obtained by
    applying cross-validation to ``estimator`` with the parameters passed as
    ``params`` to ``score``.

    The cross-validation performed is specified by the ``cv`` parameter,
    and the scoring metric is specified by the ``scoring`` parameter.
    The ``X`` and ``y`` parameters are the input data and target values,
    which are used in fit/predict cross-validation.

    Parameters
    ----------
    estimator : sklearn estimator
        The estimator to be used for the experiment. It is cloned before each
        evaluation, so the instance passed here is never fitted.
    scoring : callable, str or None
        How the model's performance is evaluated on each fold:

        * a metric function with signature ``metric(y_true, y_pred)``, e.g.,
          ``accuracy_score``; it is wrapped with ``sklearn.metrics.make_scorer``,
          so the score is the raw metric value (losses are not negated),
        * a scorer callable with signature ``scorer(estimator, X, y)``, passed
          to ``cross_validate`` unchanged; scorers are recognized by having a
          parameter named ``estimator``,
        * a string naming a scorer known to sklearn, e.g., ``"accuracy"``,
        * ``None``, in which case the estimator's own ``score`` method is used.
    cv : int or cross-validation generator
        The number of folds or cross-validation strategy to be used.
    X : array-like, shape (n_samples, n_features)
        The input data for the model.
    y : array-like, shape (n_samples,) or (n_samples, n_outputs)
        The target values for the model.

    Examples
    --------
    >>> from hyperactive.experiment.integrations import SklearnCvExperiment
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.svm import SVC
    >>> from sklearn.metrics import accuracy_score
    >>> from sklearn.model_selection import KFold
    >>>
    >>> X, y = load_iris(return_X_y=True)
    >>>
    >>> sklearn_exp = SklearnCvExperiment(
    ...     estimator=SVC(),
    ...     scoring=accuracy_score,
    ...     cv=KFold(n_splits=3, shuffle=True),
    ...     X=X,
    ...     y=y,
    ... )
    >>> params = {"C": 1.0, "kernel": "linear"}
    >>> score, add_info = sklearn_exp._score(params)
    """

    def __init__(self, estimator, scoring, cv, X, y):
        self.estimator = estimator
        self.X = X
        self.y = y
        self.scoring = scoring
        self.cv = cv

        super().__init__()

    def _paramnames(self):
        """Return the parameter names of the search.

        Returns
        -------
        list of str
            The parameter names of the search parameters.
        """
        return list(self.estimator.get_params().keys())

    def _get_scorer(self):
        """Return the ``scoring`` argument to pass to ``cross_validate``.

        ``cross_validate`` expects a scorer ``(estimator, X, y)``, a string or
        ``None``. Plain metric functions ``(y_true, y_pred)`` are therefore
        wrapped with ``make_scorer``; everything else is returned unchanged.

        Returns
        -------
        callable, str or None
            Value suitable for the ``scoring`` parameter of ``cross_validate``.
        """
        scoring = self.scoring

        # strings and None are understood by cross_validate directly
        if not callable(scoring):
            return scoring

        from inspect import signature

        try:
            scoring_params = signature(scoring).parameters
        except (TypeError, ValueError):
            # signature cannot be introspected - assume it already is a scorer
            return scoring

        # sklearn scorers take the fitted estimator as first argument
        if "estimator" in scoring_params:
            return scoring

        from sklearn.metrics import make_scorer

        return make_scorer(scoring)

    def _score(self, params):
        """Score the parameters.

        Parameters
        ----------
        params : dict with string keys
            Parameters to score.

        Returns
        -------
        float
            The mean cross-validation test score of the parameters.
        dict
            Additional metadata about the search.
        """
        estimator = clone(self.estimator)
        estimator.set_params(**params)

        cv_results = cross_validate(
            estimator,
            self.X,
            self.y,
            scoring=self._get_scorer(),
            cv=self.cv,
        )

        add_info_d = {
            "score_time": cv_results["score_time"],
            "fit_time": cv_results["fit_time"],
            # note: this is the number of samples in the full X, not per fold
            "n_test_samples": _num_samples(self.X),
        }

        return cv_results["test_score"].mean(), add_info_d

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the skbase object.

        ``get_test_params`` is a unified interface point to store
        parameter settings for testing purposes. This function is also
        used in ``create_test_instance`` and ``create_test_instances_and_names``
        to construct test instances.

        ``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``.

        Each ``dict`` is a parameter configuration for testing,
        and can be used to construct an "interesting" test instance.
        A call to ``cls(**params)`` should
        be valid for all dictionaries ``params`` in the return of ``get_test_params``.

        The ``get_test_params`` need not return fixed lists of dictionaries,
        it can also return dynamic or stochastic parameter settings.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return `"default"` set.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class.
            Each dict are parameters to construct an "interesting" test instance, i.e.,
            `MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
            `create_test_instance` uses the first (or only) dictionary in `params`.
        """
        from sklearn.datasets import load_diabetes, load_iris
        from sklearn.svm import SVC, SVR
        from sklearn.metrics import accuracy_score, mean_absolute_error
        from sklearn.model_selection import KFold

        X, y = load_iris(return_X_y=True)
        params_classif = {
            "estimator": SVC(),
            "scoring": accuracy_score,
            "cv": KFold(n_splits=3, shuffle=True),
            "X": X,
            "y": y,
        }

        X, y = load_diabetes(return_X_y=True)
        params_regress = {
            "estimator": SVR(),
            "scoring": mean_absolute_error,
            "cv": KFold(n_splits=2, shuffle=True),
            "X": X,
            "y": y,
        }

        return [params_classif, params_regress]

    @classmethod
    def _get_score_params(cls):
        """Return settings for the score function.

        Returns a list; the i-th element corresponds to ``cls.get_test_params()[i]``
        and is a valid ``params`` argument for ``score`` of that test instance.

        Returns
        -------
        list of dict
            The parameters to be used for scoring.
        """
        score_params_classif = {"C": 1.0, "kernel": "linear"}
        score_params_regress = {"C": 1.0, "kernel": "linear"}
        return [score_params_classif, score_params_regress]
183