Pull Request #101 (master), by Simon, created 02:06. Status: Passed

hyperactive.integrations.sklearn.sklearn_cv_experiment (rating: A)

Complexity

Total Complexity 3

Size/Duplication

Total Lines 60
Duplicated Lines 53.33 %

Importance

Changes 0
| Metric | Value |
|--------|-------|
| eloc   | 28    |
| dl     | 32    |
| loc    | 60    |
| rs     | 10    |
| c      | 0     |
| b      | 0     |
| f      | 0     |
| wmc    | 3     |

3 Methods

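| Rating | Name | Duplication (lines) | Size (lines) | Complexity |
|--------|------|---------------------|--------------|------------|
| A | `SklearnCvExperiment._paramnames()` | 0 | 9 | 1 |
| A | `SklearnCvExperiment.__init__()` | 0 | 6 | 1 |
| A | `SklearnCvExperiment._score()` | 32 | 32 | 1 |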
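The summary figures can be cross-checked against the tables, reading `dl` as duplicated lines, `loc` as total lines, and `wmc` as the usual weighted-methods-per-class metric (the sum of per-method complexities). Those readings are assumptions about the tool's abbreviations, but they reproduce both the 53.33 % duplication figure and the total complexity of 3:

```python
# Values copied from the metric and method tables in this report.
duplicated_lines = 32  # dl
total_lines = 60       # loc
method_complexities = {
    "_paramnames": 1,
    "__init__": 1,
    "_score": 1,
}

# Duplicated Lines percentage: dl / loc, shown with two decimals.
duplication_pct = round(duplicated_lines / total_lines * 100, 2)

# wmc (assumed: weighted methods per class) sums every method's complexity.
wmc = sum(method_complexities.values())

print(f"{duplication_pct}% duplicated, wmc = {wmc}")
```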


Duplicated Code

Duplicate code is one of the most pungent code smells. A common rule of thumb is to restructure code once it is duplicated in three or more places.
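For the `SklearnCvExperiment._score()` method flagged in the method table (32 of its 32 lines duplicated), one standard remedy is Extract Method: move the shared cross-validation body into a single function that every copy calls. The sketch below assumes the duplicates differ only in where `estimator`, `X`, `y`, `cv`, and `scoring` come from; `cv_score_with_metadata` and `ExperimentUsingHelper` are hypothetical names, not part of hyperactive. The helper cross-validates the configured clone and passes `scoring` through to `cross_validate`. Tuned parameters travel as a dict rather than `**kwargs` so that estimator parameters named `cv` or `scoring` (as on `LogisticRegressionCV`) cannot collide with the helper's own arguments.

```python
from sklearn import clone
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate
from sklearn.tree import DecisionTreeClassifier
from sklearn.utils.validation import _num_samples


def cv_score_with_metadata(estimator, X, y, cv, scoring, params):
    """Hypothetical shared helper: score one parameter set via CV.

    Returns the mean test score and a metadata dict, like _score().
    """
    # Configure an unfitted copy so the caller's estimator is untouched.
    configured = clone(estimator)
    configured.set_params(**params)

    cv_results = cross_validate(configured, X, y, scoring=scoring, cv=cv)

    metadata = {
        "score_time": cv_results["score_time"],
        "fit_time": cv_results["fit_time"],
        "n_test_samples": _num_samples(X),
    }
    return cv_results["test_score"].mean(), metadata


class ExperimentUsingHelper:
    """Hypothetical experiment whose _score delegates to the helper."""

    def __init__(self, estimator, scoring, cv, X, y):
        self.estimator = estimator
        self.scoring = scoring
        self.cv = cv
        self.X = X
        self.y = y

    def _score(self, **params):
        return cv_score_with_metadata(
            self.estimator, self.X, self.y, self.cv, self.scoring, params
        )


# Usage: iris and a decision tree are illustrative choices only.
X, y = load_iris(return_X_y=True)
experiment = ExperimentUsingHelper(
    DecisionTreeClassifier(random_state=0), scoring="accuracy", cv=3, X=X, y=y
)
score, info = experiment._score(max_depth=2)
print(f"mean accuracy: {score:.3f}")
```

Each duplicated `_score` then shrinks to a one-line delegation, so a later fix to the cross-validation logic only has to be made once.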

Common duplication problems and their corresponding solutions are:

```python
"""Experiment adapter for sklearn cross-validation experiments."""

from sklearn import clone
from sklearn.model_selection import cross_validate
from sklearn.utils.validation import _num_samples

from hyperactive.base import BaseExperiment


class SklearnCvExperiment(BaseExperiment):

    def __init__(self, estimator, scoring, cv, X, y):
        self.estimator = estimator
        self.X = X
        self.y = y
        self.scoring = scoring
        self.cv = cv

    def _paramnames(self):
        """Return the parameter names of the search.

        Returns
        -------
        list of str
            The parameter names of the search parameters.
        """
        return list(self.estimator.get_params().keys())

    # Flagged by the duplication check: this method seems to be duplicated
    # elsewhere in the project.
    def _score(self, **params):
        """Score the parameters.

        Parameters
        ----------
        params : dict with string keys
            Parameters to score.

        Returns
        -------
        float
            The score of the parameters.
        dict
            Additional metadata about the search.
        """
        # Work on an unfitted copy so the user's estimator is never mutated.
        estimator = clone(self.estimator)
        estimator.set_params(**params)

        # Cross-validate the configured copy (not self.estimator), using the
        # scoring rule supplied at construction time.
        cv_results = cross_validate(
            estimator,
            self.X,
            self.y,
            scoring=self.scoring,
            cv=self.cv,
        )

        add_info_d = {
            "score_time": cv_results["score_time"],
            "fit_time": cv_results["fit_time"],
            # Counts all samples in X, not the per-fold test size.
            "n_test_samples": _num_samples(self.X),
        }

        return cv_results["test_score"].mean(), add_info_d
```
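Whether `_score` cross-validates the configured clone or `self.estimator` matters, because `set_params` only changes the clone. The difference can be checked without hyperactive. In the sketch below, iris and `DecisionTreeClassifier` are illustrative assumptions. A depth-1 tree has only two leaves, so it can predict at most two of the three iris classes, and its score must drop visibly if `max_depth=1` really reaches the model that is evaluated.

```python
from sklearn import clone
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
base = DecisionTreeClassifier(random_state=0)

# Same two steps as in _score: clone the estimator, then apply tuned params.
configured = clone(base)
configured.set_params(max_depth=1)

# Cross-validating the original estimator silently ignores max_depth=1.
score_original = cross_validate(
    base, X, y, scoring="accuracy", cv=3
)["test_score"].mean()

# Cross-validating the configured clone actually evaluates max_depth=1.
score_configured = cross_validate(
    configured, X, y, scoring="accuracy", cv=3
)["test_score"].mean()

# The depth-1 tree can only get two of three classes right, so its mean
# accuracy stays well below that of the unrestricted tree.
print(f"original: {score_original:.3f}, configured: {score_configured:.3f}")
```

The same check also shows why forwarding `scoring` matters: without it, `cross_validate` falls back to the estimator's default `score` method, which happens to be accuracy for classifiers but would silently ignore any other metric passed to the experiment.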