Passed: Push — master (65e4f4...bd0b66), by Simon, created 01:38

ForecastingOptCV.get_test_params()   A

Complexity
  Conditions: 1

Size
  Total Lines: 51
  Code Lines: 32

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric  Value
eloc    32
dl      0
loc     51
rs      9.112
c       0
b       0
f       0
cc      1
nop     2

How to fix: Long Method

Small methods make your code easier to understand, particularly when combined with a good name. Moreover, if a method is small, finding a good name for it is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign that the commented part should be extracted into a new method, with the comment as a starting point for the new method's name.

Commonly applied refactorings include Extract Method, Replace Temp with Query, and Decompose Conditional. A minimal Extract Method sketch follows.
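
A minimal Extract Method sketch in Python, with hypothetical names unrelated to the code under review:

    # Before: one method doing two jobs, with a comment marking the seam.
    def report(data):
        # compute summary statistics
        total = sum(data)
        mean = total / len(data)
        print(f"total={total}, mean={mean}")

    # After: the commented step becomes a well-named helper, and the
    # comment itself is no longer needed.
    def summary_stats(data):
        total = sum(data)
        return total, total / len(data)

    def report(data):
        total, mean = summary_stats(data)
        print(f"total={total}, mean={mean}")

The full source of the reviewed file follows.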

# copyright: hyperactive developers, MIT License (see LICENSE file)

import numpy as np
from skbase.utils.dependencies import _check_soft_dependencies

if _check_soft_dependencies("sktime", severity="none"):
    from sktime.forecasting.base._delegate import _DelegatedForecaster
else:
    from skbase.base import BaseEstimator as _DelegatedForecaster

from hyperactive.experiment.integrations.sktime_forecasting import (
    SktimeForecastingExperiment,
)


class ForecastingOptCV(_DelegatedForecaster):
    """Tune an sktime forecaster via any optimizer in the hyperactive API.

    Parameters
    ----------
    forecaster : sktime forecaster, BaseForecaster instance or interface compatible
        The forecaster to tune, must implement the sktime forecaster interface.

    optimizer : hyperactive BaseOptimizer
        The optimizer to be used for hyperparameter search.

    cv : sktime BaseSplitter descendant
        determines split of ``y`` and possibly ``X`` into test and train folds
        y is always split according to ``cv``, see above
        if ``cv_X`` is not passed, ``X`` splits are subset to ``loc`` equal to ``y``
        if ``cv_X`` is passed, ``X`` is split according to ``cv_X``

    strategy : {"refit", "update", "no-update_params"}, optional, default="refit"
        defines the ingestion mode when the forecaster sees new data as the
        window expands
        "refit" = forecaster is refitted to each training window
        "update" = forecaster is updated with training window data, in sequence provided
        "no-update_params" = fit to first training window, re-used without fit or update

    update_behaviour : str, optional, default="full_refit"
        one of {"full_refit", "inner_only", "no_update"}
        behaviour of the forecaster when calling update
        "full_refit" = both tuning parameters and inner estimator refit on all data seen
        "inner_only" = tuning parameters are not re-tuned, inner estimator is updated
        "no_update" = neither tuning parameters nor inner estimator are updated

    scoring : sktime metric (BaseMetric), str, or callable, optional (default=None)
        scoring metric to use in tuning the forecaster

        * sktime metric objects (BaseMetric descendants) can be searched
          with the ``registry.all_estimators`` search utility,
          for instance via ``all_estimators("metric", as_dataframe=True)``

        * If callable, must have signature
          ``(y_true: 1D np.ndarray, y_pred: 1D np.ndarray) -> float``,
          assuming np.ndarrays of the same length, and lower being better.
          Metrics in sktime.performance_metrics.forecasting are all of this form.

        * If str, uses registry.resolve_alias to resolve to one of the above.
          Valid strings are valid registry.craft specs, which include
          string repr-s of any BaseMetric object, e.g., "MeanSquaredError()";
          and keys of registry.ALIAS_DICT referring to metrics.

        * If None, defaults to MeanAbsolutePercentageError()
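
        For illustration only, a hand-rolled callable of the required form
        (``my_mape`` is a hypothetical name, not a metric shipped with sktime):

        >>> import numpy as np
        >>> def my_mape(y_true, y_pred):
        ...     return float(np.mean(np.abs((y_true - y_pred) / y_true)))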

    refit : bool, optional (default=True)
        True = refit the forecaster with the best parameters on the entire data in fit
        False = no refitting takes place. The forecaster cannot be used to predict.
        This is to be used to tune the hyperparameters, and then use the estimator
        as a parameter estimator, e.g., via get_fitted_params or PluginParamsForecaster.

    error_score : "raise" or numeric, default=np.nan
        Value to assign to the score if an exception occurs in estimator fitting. If set
        to "raise", the exception is raised. If a numeric value is given,
        FitFailedWarning is raised.

    cv_X : sktime BaseSplitter descendant, optional
        determines split of ``X`` into test and train folds
        default is ``X`` being split to identical ``loc`` indices as ``y``
        if passed, must have same number of splits as ``cv``

    backend : string, optional, default="None"
        Parallelization backend to use for runs.
        Runs parallel evaluate if specified and ``strategy="refit"``.

        - "None": executes loop sequentially, simple list comprehension
        - "loky", "multiprocessing" and "threading": uses ``joblib.Parallel`` loops
        - "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``
        - "dask": uses ``dask``, requires ``dask`` package in environment
        - "dask_lazy": same as "dask",
          but changes the return to (lazy) ``dask.dataframe.DataFrame``.
        - "ray": uses ``ray``, requires ``ray`` package in environment

        Recommendation: use "dask" or "loky" for parallel evaluate.
        "threading" is unlikely to see speed-ups due to the GIL, and the serialization
        backend (``cloudpickle``) for "dask" and "loky" is generally more robust
        than the standard ``pickle`` library used in "multiprocessing".

    backend_params : dict, optional
        additional parameters passed to the backend as config.
        Directly passed to ``utils.parallel.parallelize``.
        Valid keys depend on the value of ``backend``:

        - "None": no additional parameters, ``backend_params`` is ignored
        - "loky", "multiprocessing" and "threading": default ``joblib`` backends;
          any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
          with the exception of ``backend``, which is directly controlled by
          the ``backend`` parameter.
          If ``n_jobs`` is not passed, it will default to ``-1``; other parameters
          will default to ``joblib`` defaults.
        - "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``;
          any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``;
          ``backend`` must be passed as a key of ``backend_params`` in this case.
          If ``n_jobs`` is not passed, it will default to ``-1``; other parameters
          will default to ``joblib`` defaults.
        - "dask": any valid keys for ``dask.compute`` can be passed,
          e.g., ``scheduler``
        - "ray": the following keys can be passed:

          - "ray_remote_args": dictionary of valid keys for ``ray.init``
          - "shutdown_ray": bool, default=True; False prevents ``ray`` from shutting
            down after parallelization.
          - "logger_name": str, default="ray"; name of the logger to use.
          - "mute_warnings": bool, default=False; if True, suppresses warnings
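
        For illustration only, a sketch of a "loky" backend configuration
        limited to two worker processes (any valid ``joblib.Parallel`` key
        could be used in place of ``n_jobs``):

        >>> backend = "loky"
        >>> backend_params = {"n_jobs": 2}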

    Example
    -------
    Tuning an sktime forecaster via grid search

    1. defining the tuned forecaster:
    >>> from sktime.forecasting.naive import NaiveForecaster
    >>> from sktime.split import SingleWindowSplitter
    >>> from hyperactive.opt import GridSearchSk as GridSearch
    >>>
    >>> param_grid = {"window_length": [2, 5]}
    >>> tuned_naive = ForecastingOptCV(
    ...     NaiveForecaster(strategy="mean"),
    ...     GridSearch(param_grid=param_grid),
    ...     cv=SingleWindowSplitter(fh=1),
    ... )

    2. fitting the tuned forecaster:
    >>> from sktime.datasets import load_airline
    >>> y = load_airline()
    >>> tuned_naive.fit(y, fh=1)
    ForecastingOptCV(...)
    >>> y_pred = tuned_naive.predict()

    3. obtaining best parameters and best forecaster
    >>> best_params = tuned_naive.best_params_
    >>> best_forecaster = tuned_naive.best_forecaster_
    """

    _tags = {
        "authors": "fkiraly",
        "maintainers": "fkiraly",
        "python_dependencies": "sktime",
    }

    # attribute for _DelegatedForecaster delegation:
    #     all non-overridden methods behave as those of getattr(self, _delegate_name)
    #     see further details in the _DelegatedForecaster docstring
    _delegate_name = "best_forecaster_"

    def __init__(
        self,
        forecaster,
        optimizer,
        cv,
        strategy="refit",
        update_behaviour="full_refit",
        scoring=None,
        refit=True,
        error_score=np.nan,
        cv_X=None,
        backend=None,
        backend_params=None,
    ):
        self.forecaster = forecaster
        self.optimizer = optimizer
        self.cv = cv
        self.strategy = strategy
        self.update_behaviour = update_behaviour
        self.scoring = scoring
        self.refit = refit
        self.error_score = error_score
        self.cv_X = cv_X
        self.backend = backend
        self.backend_params = backend_params
        super().__init__()

    def _fit(self, y, X, fh):
        """Fit to training data.

        Parameters
        ----------
        y : pd.Series
            Target time series to which to fit the forecaster.
        X : pd.DataFrame, optional (default=None)
            Exogenous variables, passed on to the tuning experiment,
            and used to refit the best forecaster if ``refit=True``.
        fh : int, list or np.array, optional (default=None)
            The forecaster's horizon with the steps ahead to predict.

        Returns
        -------
        self : returns an instance of self.
        """
        from sktime.utils.validation.forecasting import check_scoring

        forecaster = self.forecaster.clone()

        scoring = check_scoring(self.scoring, obj=self)
        # scoring_name = f"test_{scoring.name}"

        # the experiment encapsulates one cross-validated evaluation of the
        # forecaster on (y, X); the optimizer searches over its parameters
        experiment = SktimeForecastingExperiment(
            forecaster=forecaster,
            scoring=scoring,
            cv=self.cv,
            X=X,
            y=y,
            strategy=self.strategy,
            error_score=self.error_score,
            cv_X=self.cv_X,
            backend=self.backend,
            backend_params=self.backend_params,
        )

        optimizer = self.optimizer.clone()
        optimizer.set_params(experiment=experiment)
        best_params = optimizer.run()

        self.best_params_ = best_params
        self.best_forecaster_ = forecaster.set_params(**best_params)

        # refit model with best parameters
        if self.refit:
            self.best_forecaster_.fit(y=y, X=X, fh=fh)

        return self

    def _predict(self, fh, X):
        """Forecast time series at future horizon.

        private _predict containing the core logic, called from predict

        State required:
            Requires state to be "fitted".

        Accesses in self:
            Fitted model attributes ending in "_"
            self.cutoff

        Parameters
        ----------
        fh : guaranteed to be ForecastingHorizon or None, optional (default=None)
            The forecasting horizon with the steps ahead to predict.
            If not passed in _fit, guaranteed to be passed here
        X : pd.DataFrame, optional (default=None)
            Exogenous time series

        Returns
        -------
        y_pred : pd.Series
            Point predictions
        """
        if not self.refit:
            raise RuntimeError(
                f"In {self.__class__.__name__}, refit must be True to make predictions,"
                f" but found refit=False. If refit=False, {self.__class__.__name__} can"
                " be used only to tune hyper-parameters, as a parameter estimator."
            )
        return super()._predict(fh=fh, X=X)

    def _update(self, y, X=None, update_params=True):
        """Update time series to incremental training data.

        Parameters
        ----------
        y : guaranteed to be of a type in self.get_tag("y_inner_mtype")
            Time series with which to update the forecaster.
            if self.get_tag("scitype:y")=="univariate":
                guaranteed to have a single column/variable
            if self.get_tag("scitype:y")=="multivariate":
                guaranteed to have 2 or more columns
            if self.get_tag("scitype:y")=="both": no restrictions apply
        X : optional (default=None)
            guaranteed to be of a type in self.get_tag("X_inner_mtype")
            Exogenous time series for the forecast
        update_params : bool, optional (default=True)
            whether model parameters should be updated

        Returns
        -------
        self : reference to self
        """
        update_behaviour = self.update_behaviour

        if update_behaviour == "full_refit":
            # both tuning parameters and inner estimator are refit on all data seen
            super()._update(y=y, X=X, update_params=update_params)
        elif update_behaviour == "inner_only":
            # tuning parameters are not re-tuned, inner estimator is updated
            self.best_forecaster_.update(y=y, X=X, update_params=update_params)
        elif update_behaviour == "no_update":
            # neither tuning parameters nor inner estimator are updated
            self.best_forecaster_.update(y=y, X=X, update_params=False)
        else:
            raise ValueError(
                'update_behaviour must be one of "full_refit", "inner_only",'
                f' or "no_update", but found {update_behaviour}'
            )
        return self

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return ``"default"`` set.

        Returns
        -------
        params : dict or list of dict
        """
        from sktime.forecasting.naive import NaiveForecaster
        from sktime.forecasting.trend import PolynomialTrendForecaster
        from sktime.performance_metrics.forecasting import (
            MeanAbsolutePercentageError,
            mean_absolute_percentage_error,
        )
        from sktime.split import SingleWindowSplitter

        from hyperactive.opt.gfo import HillClimbing
        from hyperactive.opt.gridsearch import GridSearchSk
        from hyperactive.opt.random_search import RandomSearchSk

        # set 1: grid search over window length, metric object as scoring
        params_gridsearch = {
            "forecaster": NaiveForecaster(strategy="mean"),
            "cv": SingleWindowSplitter(fh=1),
            "optimizer": GridSearchSk(param_grid={"window_length": [2, 5]}),
            "scoring": MeanAbsolutePercentageError(symmetric=True),
        }
        # set 2: random search over polynomial degree, metric function as scoring
        params_randomsearch = {
            "forecaster": PolynomialTrendForecaster(),
            "cv": SingleWindowSplitter(fh=1),
            "optimizer": RandomSearchSk(param_distributions={"degree": [1, 2]}),
            "scoring": mean_absolute_percentage_error,
            "update_behaviour": "inner_only",
        }
        # set 3: hill climbing optimizer, metric string alias as scoring
        params_hillclimb = {
            "forecaster": NaiveForecaster(strategy="mean"),
            "cv": SingleWindowSplitter(fh=1),
            "optimizer": HillClimbing(
                search_space={"window_length": [2, 5]},
                n_iter=10,
                n_neighbours=5,
            ),
            "scoring": "MeanAbsolutePercentageError(symmetric=True)",
            "update_behaviour": "no_update",
        }
        return [params_gridsearch, params_randomsearch, params_hillclimb]
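
For context on the flagged method: ``get_test_params`` returns plain keyword
dicts for the constructor, so a test harness might consume it as in the
following sketch (hypothetical harness code, not part of the file above):

    # each returned dict is a valid kwargs set for ForecastingOptCV
    for params in ForecastingOptCV.get_test_params():
        inst = ForecastingOptCV(**params)  # construct one instance per test set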