_safe_refit()   A
last analyzed

Complexity

Conditions 3

Size

Total Lines 10
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 6
dl 0
loc 10
rs 10
c 0
b 0
f 0
cc 3
nop 4
1
"""Internal helpers that bridge behavioural differences between scikit-learn versions.

Import *private* scikit-learn symbols **only** here and nowhere else.

Copyright: Hyperactive contributors
License: MIT
"""
8
9
from __future__ import annotations
10
11
import warnings
12
from typing import Any
13
14
import sklearn
15
from packaging import version
16
from sklearn.utils.validation import indexable
17
18
# Installed scikit-learn version, parsed once at import time with `packaging`
# so it can be compared against `version.parse(...)` thresholds.
_SK_VERSION = version.parse(sklearn.__version__)
19
20
21
def _safe_validate_X_y(estimator, X, y):
    """Validate ``X`` and ``y`` independently of the scikit-learn version.

    Version-independent replacement for a naive ``validate_data(X, y)`` call:

    • Ensures X is 2-D.
    • Allows y to stay 1-D (required by scikit-learn >=1.7 checks).
    • Uses scikit-learn's estimator-aware validation whenever it exists, so
      that estimator tags and ``n_features_in_`` bookkeeping keep working.

    Resolution order:

    1. ``sklearn.utils.validation.validate_data`` (public since 1.6).
    2. ``BaseEstimator._validate_data`` (older releases; deprecated in 1.6).
    3. ``check_X_y`` as a last resort. This path does not record
       ``n_features_in_`` on ``estimator``.

    Parameters
    ----------
    estimator : estimator instance
        Estimator on whose behalf the data is validated.
    X : array-like
        Input samples.
    y : array-like
        Target values.

    Returns
    -------
    tuple
        The validated ``(X, y)`` pair.
    """
    X, y = indexable(X, y)

    # X is forced to 2-D, y is allowed to stay 1-D.
    separate_params = (
        {"ensure_2d": True},  # parameters for X
        {"ensure_2d": False},  # parameters for y
    )

    # Preferred path: public function added in scikit-learn 1.6. Without it,
    # scikit-learn >=1.7 (where BaseEstimator._validate_data no longer exists)
    # would silently drop to check_X_y and skip n_features_in_ bookkeeping.
    try:
        from sklearn.utils.validation import validate_data
    except ImportError:
        validate_data = None

    if validate_data is not None:
        return validate_data(
            estimator,
            X,
            y,
            validate_separately=separate_params,
        )

    # Releases before 1.6 expose the same logic as a private estimator method.
    if hasattr(estimator, "_validate_data"):
        return estimator._validate_data(
            X,
            y,
            validate_separately=separate_params,
        )

    # Last resort for very old scikit-learn versions (<0.23).
    from sklearn.utils.validation import check_X_y

    return check_X_y(X, y, ensure_2d=True)
46
47
48
def _safe_refit(estimator, X, y, fit_params):
    """Optionally refit ``estimator`` and record ``n_features_in_`` on it.

    If ``estimator.refit`` is truthy, ``estimator._refit(X, y, **fit_params)``
    is invoked, and ``n_features_in_`` is copied from ``best_estimator_`` when
    that object exposes it. If ``refit`` is falsy, no refit happens and
    ``n_features_in_`` is taken from ``X.shape[1]``.

    Returns
    -------
    None
    """
    if not estimator.refit:
        # No refit requested: still satisfy the n_features_in_ contract
        # using the training data itself.
        estimator.n_features_in_ = X.shape[1]
        return

    estimator._refit(X, y, **fit_params)

    # Let the wrapper mirror the refitted estimator's feature count.
    best = estimator.best_estimator_
    if hasattr(best, "n_features_in_"):
        estimator.n_features_in_ = best.n_features_in_
58
59
60
# Replacement for `_deprecate_Xt_in_inverse_transform`
61
try:
    # Still shipped by scikit-learn -> re-export it unchanged. Feature
    # detection (rather than comparing _SK_VERSION) also copes with
    # pre-releases such as "1.7.0rc1", which compare lower than "1.7".
    from sklearn.utils.deprecation import _deprecate_Xt_in_inverse_transform
except ImportError:
    # Not available in this scikit-learn (removed in 1.7) - provide a
    # drop-in replacement.
    def _deprecate_Xt_in_inverse_transform(  # noqa: N802  keep sklearn's name
        X: Any | None,
        Xt: Any | None,
    ):
        """Resolve the ``X`` / deprecated ``Xt`` arguments of ``inverse_transform``.

        scikit-learn 1.5 and 1.6 accepted both the new ``X`` parameter and
        the deprecated ``Xt`` parameter for ``inverse_transform``. This
        replacement reproduces that behaviour.

        Parameters
        ----------
        X : array-like or None
            Value passed through the current ``X`` parameter.
        Xt : array-like or None
            Value passed through the deprecated ``Xt`` parameter.

        Returns
        -------
        array-like
            ``Xt`` if only ``Xt`` was given (a ``FutureWarning`` is
            emitted), otherwise ``X``.

        Raises
        ------
        TypeError
            If both ``X`` and ``Xt`` are given, or if neither is given.
        """
        # Giving both is ambiguous; returning Xt here would silently
        # discard X.
        if X is not None and Xt is not None:
            raise TypeError("Cannot use both X and Xt. Use X only.")

        if X is None and Xt is None:
            raise TypeError("Missing required positional argument: X.")

        if Xt is not None:
            warnings.warn(
                "'Xt' was deprecated in scikit-learn 1.5 and has been "
                "removed in 1.7; use the positional argument 'X' instead.",
                FutureWarning,
                stacklevel=2,
            )
            return Xt
        return X
86
87
88
# Replacement for `_check_method_params`
89
try:
    from sklearn.utils.validation import _check_method_params  # noqa: F401
except ImportError:  # fallback for future releases

    def _check_method_params(  # type: ignore[override]  # noqa: N802
        X,
        params: dict[str, Any],
    ):
        """Return a copy of the method ``params`` (passthrough fallback).

        Used only when scikit-learn no longer provides its private
        ``_check_method_params``. No per-parameter validation or indexing
        happens here; the estimator and ``indexable`` are relied on instead.

        A new dict is returned, as scikit-learn's implementation does, so
        callers may modify the result without mutating the caller-supplied
        ``params``.

        Parameters
        ----------
        X : array-like
            Input data. Unused by this fallback; kept for signature parity.
        params : dict
            Keyword arguments destined for an estimator method such as ``fit``.

        Returns
        -------
        dict
            Shallow copy of ``params``.
        """
        return dict(params)
99
100
101
# Names this compatibility module provides to the rest of the package: the
# re-exported (or replacement) scikit-learn shims plus the version-bridging
# helpers defined here. Listing the conditionally imported shims also marks
# them as intentional re-exports for linters such as pyflakes.
__all__ = [
    "_deprecate_Xt_in_inverse_transform",
    "_check_method_params",
    "_safe_validate_X_y",
    "_safe_refit",
]
105