| Total Complexity | 4 |
| Total Lines | 37 |
| Duplicated Lines | 0 % |
| Changes | 0 | ||
| 1 | # Author: Simon Blanke |
||
| 2 | # Email: [email protected] |
||
| 3 | # License: MIT License |
||
| 4 | |||
| 5 | |||
| 6 | from sklearn.model_selection import cross_validate |
||
| 7 | from sklearn.utils.validation import _num_samples |
||
| 8 | |||
| 9 | |||
class ObjectiveFunctionAdapter:
    """Wrap a scikit-learn estimator as a cross-validated objective function.

    Construct with an estimator, attach data with ``add_dataset`` and
    validation settings with ``add_validation``, then call
    ``objective_function`` with a set of hyperparameters. Both ``add_*``
    methods must be called before ``objective_function``, otherwise the
    attributes it reads do not exist.
    """

    def __init__(self, estimator) -> None:
        # Template estimator. objective_function clones it for every call, so
        # the search never mutates or fits this object directly.
        self.estimator = estimator

    def add_dataset(self, X, y):
        """Store the features ``X`` and targets ``y`` used for cross-validation."""
        self.X = X
        self.y = y

    def add_validation(self, scoring, cv):
        """Store the ``scoring`` and ``cv`` arguments forwarded to ``cross_validate``."""
        self.scoring = scoring
        self.cv = cv

    def objective_function(self, params):
        """Score the estimator configured with ``params`` by cross-validation.

        Parameters
        ----------
        params : mapping
            Hyperparameter names and values. They are applied to a fresh clone
            of the template estimator via ``set_params(**params)``, so the
            object must support ``keys()`` and item access.

        Returns
        -------
        tuple
            ``(mean_test_score, add_info_d)``. ``mean_test_score`` is the mean
            of ``cross_validate``'s ``"test_score"`` array. ``add_info_d``
            holds the per-fold ``"score_time"`` and ``"fit_time"`` arrays and
            ``"n_test_samples"``.

        Raises
        ------
        ValueError
            If the configured ``scoring`` is multi-metric, so that
            ``cross_validate`` returns no single ``"test_score"`` entry.
        """
        # Function-scope import: sklearn is already a dependency of this
        # module, and this keeps the change local to the method.
        from sklearn.base import clone

        # A fresh, unfitted copy per trial: the searched parameters actually
        # take effect, and the user's template estimator stays untouched.
        estimator = clone(self.estimator)
        estimator.set_params(**params)

        cv_results = cross_validate(
            estimator,
            self.X,
            self.y,
            scoring=self.scoring,
            cv=self.cv,
        )

        # With multi-metric scoring, cross_validate names its keys
        # "test_<metric>" instead of "test_score"; fail with a clear message.
        if "test_score" not in cv_results:
            raise ValueError(
                "ObjectiveFunctionAdapter supports only single-metric scoring; "
                f"got scoring={self.scoring!r}"
            )

        add_info_d = {
            "score_time": cv_results["score_time"],
            "fit_time": cv_results["fit_time"],
            # NOTE(review): this is the size of the whole dataset X, not the
            # per-fold test-split size the key name suggests — confirm intent.
            "n_test_samples": _num_samples(self.X),
        }

        return cv_results["test_score"].mean(), add_info_d
||
| 37 |