Passed
Push — master ( 964643...4a40b1 )
by Simon
01:40
created

_safe_refit()   A

Complexity

Conditions 3

Size

Total Lines 10
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 10
rs 10
c 0
b 0
f 0
cc 3
nop 4
1
"""
2
Internal helpers that bridge behavioural differences between
3
scikit-learn versions.  Import *private* scikit-learn symbols **only**
4
here and nowhere else.
5
6
Copyright: Hyperactive contributors
7
License: MIT
8
"""
9
10
from __future__ import annotations
11
12
import warnings
13
from typing import Dict, Any
14
15
import sklearn
16
from packaging import version
17
from sklearn.utils.validation import indexable
18
19
# Installed scikit-learn version, parsed once at import time so that
# version-gated compatibility shims can compare against it cheaply.
_SK_VERSION = version.parse(version=sklearn.__version__)
20
21
22
def _safe_validate_X_y(estimator, X, y):
    """
    Version-independent replacement for naive validate_data(X, y).

    • Ensures X is 2-D.
    • Allows y to stay 1-D (required by scikit-learn >=1.7 checks).
    • Uses scikit-learn's estimator-aware validation whenever it is available,
      so that estimator tags, ``n_features_in_`` and sample-weight checks keep
      working:

      - the public ``sklearn.utils.validation.validate_data`` (scikit-learn
        1.6+), otherwise
      - the legacy ``BaseEstimator._validate_data`` method (removed in 1.7),
        otherwise
      - plain ``check_X_y`` as a last resort.

    Parameters
    ----------
    estimator : object
        Estimator on whose behalf the data is validated.
    X : array-like
        Feature matrix; validated with ``ensure_2d=True``.
    y : array-like
        Targets; validated with ``ensure_2d=False`` on the estimator-aware
        paths.

    Returns
    -------
    tuple
        The validated ``(X, y)`` pair.
    """
    X, y = indexable(X, y)

    # Per-input keyword arguments forwarded to scikit-learn's check_array.
    separate_params = (
        {"ensure_2d": True},  # parameters for X
        {"ensure_2d": False},  # parameters for y
    )

    try:
        # Public replacement for BaseEstimator._validate_data, which was
        # deprecated in 1.6 and removed in 1.7 -- prefer it when present.
        from sklearn.utils.validation import validate_data
    except ImportError:
        validate_data = None

    if validate_data is not None:
        return validate_data(estimator, X, y, validate_separately=separate_params)

    if hasattr(estimator, "_validate_data"):
        return estimator._validate_data(X, y, validate_separately=separate_params)

    # Last resort: no estimator-aware validation is available at all.
    from sklearn.utils.validation import check_X_y

    return check_X_y(X, y, ensure_2d=True)
47
48
49
def _safe_refit(estimator, X, y, fit_params):
50
    if estimator.refit:
51
        estimator._refit(X, y, **fit_params)
52
53
        # make the wrapper itself expose n_features_in_
54
        if hasattr(estimator.best_estimator_, "n_features_in_"):
55
            estimator.n_features_in_ = estimator.best_estimator_.n_features_in_
56
    else:
57
        # Even when `refit=False` we must satisfy the contract
58
        estimator.n_features_in_ = X.shape[1]
59
60
61
# Replacement for `_deprecate_Xt_in_inverse_transform`
62
if _SK_VERSION < version.parse("1.7"):
63
    # Still exists → re-export
64
    from sklearn.utils.deprecation import _deprecate_Xt_in_inverse_transform
65
else:
66
    # Removed in 1.7 → provide drop-in replacement
67
    def _deprecate_Xt_in_inverse_transform(  # noqa: N802  keep sklearn’s name
68
        X: Any | None,
69
        Xt: Any | None,
70
    ):
71
        """
72
        scikit-learn ≤1.6 accepted both the old `Xt` parameter and the new
73
        `X` parameter for `inverse_transform`.  When only `Xt` is given we
74
        return `Xt` and raise a deprecation warning (same behaviour that
75
        scikit-learn had before 1.7); otherwise we return `X`.
76
        """
77
        if Xt is not None:
78
            warnings.warn(
79
                "'Xt' was deprecated in scikit-learn 1.2 and has been "
80
                "removed in 1.7; use the positional argument 'X' instead.",
81
                FutureWarning,
82
                stacklevel=2,
83
            )
84
            return Xt
85
        return X
86
87
88
# Replacement for `_check_method_params`
89
try:
90
    from sklearn.utils.validation import _check_method_params  # noqa: F401
91
except ImportError:  # fallback for future releases
92
93
    def _check_method_params(  # type: ignore[override]  # noqa: N802
94
        X,
95
        params: Dict[str, Any],
96
    ):
97
        # passthrough – rely on estimator & indexable for validation
98
        return params
99
100
101
# Helpers other modules of the package import from here: the version-bridging
# functions defined in this module plus the private scikit-learn symbols it
# re-exports (or replaces).
__all__ = [
    "_safe_validate_X_y",
    "_safe_refit",
    "_deprecate_Xt_in_inverse_transform",
    "_check_method_params",
]
105