Total Complexity | 4 |
Total Lines | 39 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | # Author: Simon Blanke |
||
2 | # Email: [email protected] |
||
3 | # License: MIT License |
||
4 | |||
5 | |||
6 | from sklearn.utils.validation import ( |
||
7 | indexable, |
||
8 | _check_method_params, |
||
9 | check_is_fitted, |
||
10 | ) |
||
11 | |||
12 | # NOTE: Implementations of the following functions are taken from: |
||
13 | # https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/model_selection/_search.py |
||
14 | # Tag: 1.5.1 |
||
15 | |||
16 | |||
def _check_refit(search_cv, attr):
    """Ensure ``search_cv`` was configured to refit before exposing ``attr``.

    Parameters
    ----------
    search_cv : object
        The search object; only its ``refit`` attribute is inspected.
    attr : str
        Name of the attribute/method being requested; used in the message.

    Raises
    ------
    AttributeError
        If ``search_cv.refit`` is falsy.
    """
    if search_cv.refit:
        return
    cls_name = type(search_cv).__name__
    raise AttributeError(
        f"This {cls_name} instance was initialized with `refit=False`. "
        f"{attr} is available only after refitting on the best parameters. "
        "You can refit an estimator manually using the `best_params_` attribute"
    )
||
25 | |||
26 | |||
def _estimator_has(attr):
    """Build a predicate reporting whether ``attr`` can be delegated.

    The returned callable takes the search object as its only argument. It
    first calls ``_check_refit`` on it, which raises ``AttributeError`` when
    ``refit`` is falsy. It then looks ``attr`` up on ``best_estimator_`` if
    that attribute is set, otherwise on ``estimator``. A missing ``attr``
    surfaces as the ``AttributeError`` raised by ``getattr``. On success the
    predicate returns ``True``.

    Parameters
    ----------
    attr : str
        Name of the attribute/method to look for on the wrapped estimator.

    Returns
    -------
    callable
        A one-argument predicate suitable for conditional delegation.
    """

    def check(self):
        _check_refit(self, attr)
        # Prefer ``best_estimator_``; fall back to ``estimator`` when it is unset.
        target = (
            self.best_estimator_
            if hasattr(self, "best_estimator_")
            else self.estimator
        )
        # Called only for its side effect: raises AttributeError if absent.
        getattr(target, attr)
        return True

    return check
||
39 |