ForecastingOptCV._predict()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 32
Code Lines 5

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 5
dl 0
loc 32
rs 10
c 0
b 0
f 0
cc 2
nop 3
1
# copyright: hyperactive developers, MIT License (see LICENSE file)
2
3
import numpy as np
4
from skbase.utils.dependencies import _check_soft_dependencies
5
6
if _check_soft_dependencies("sktime", severity="none"):
7
    from sktime.forecasting.base._delegate import _DelegatedForecaster
8
else:
9
    from skbase.base import BaseEstimator as _DelegatedForecaster
10
11
from hyperactive.experiment.integrations.sktime_forecasting import (
12
    SktimeForecastingExperiment,
13
)
14
15
16
class ForecastingOptCV(_DelegatedForecaster):
    """Tune an sktime forecaster via any optimizer in the hyperactive API.

    Parameters
    ----------
    forecaster : sktime forecaster, BaseForecaster instance or interface compatible
        The forecaster to tune, must implement the sktime forecaster interface.

    optimizer : hyperactive BaseOptimizer
        The optimizer to be used for hyperparameter search.

    cv : sktime BaseSplitter descendant
        determines split of ``y`` and possibly ``X`` into test and train folds
        y is always split according to ``cv``
        if ``cv_X`` is not passed, ``X`` splits are subset to ``loc`` equal to ``y``
        if ``cv_X`` is passed, ``X`` is split according to ``cv_X``

    strategy : {"refit", "update", "no-update_params"}, optional, default="refit"
        defines the ingestion mode when the forecaster sees new data when window expands
        "refit" = forecaster is refitted to each training window
        "update" = forecaster is updated with training window data, in sequence provided
        "no-update_params" = fit to first training window, re-used without fit or update

    update_behaviour : str, optional, default = "full_refit"
        one of {"full_refit", "inner_only", "no_update"}
        behaviour of the forecaster when calling update
        "full_refit" = both tuning parameters and inner estimator refit on all data seen
        "inner_only" = tuning parameters are not re-tuned, inner estimator is updated
        "no_update" = neither tuning parameters nor inner estimator are updated

    scoring : sktime metric (BaseMetric), str, or callable, optional (default=None)
        scoring metric to use in tuning the forecaster

        * sktime metric objects (BaseMetric) descendants can be searched
        with the ``registry.all_estimators`` search utility,
        for instance via ``all_estimators("metric", as_dataframe=True)``

        * If callable, must have signature
        ``(y_true: 1D np.ndarray, y_pred: 1D np.ndarray) -> float``,
        assuming np.ndarrays being of the same length, and lower being better.
        Metrics in sktime.performance_metrics.forecasting are all of this form.

        * If str, uses registry.resolve_alias to resolve to one of the above.
          Valid strings are valid registry.craft specs, which include
          string repr-s of any BaseMetric object, e.g., "MeanSquaredError()";
          and keys of registry.ALIAS_DICT referring to metrics.

        * If None, defaults to MeanAbsolutePercentageError()

    refit : bool, optional (default=True)
        True = refit the forecaster with the best parameters on the entire data in fit
        False = no refitting takes place. The forecaster cannot be used to predict.
        This is to be used to tune the hyperparameters, and then use the estimator
        as a parameter estimator, e.g., via get_fitted_params or PluginParamsForecaster.

    error_score : "raise" or numeric, default=np.nan
        Value to assign to the score if an exception occurs in estimator fitting. If set
        to "raise", the exception is raised. If a numeric value is given,
        FitFailedWarning is raised.

    cv_X : sktime BaseSplitter descendant, optional
        determines split of ``X`` into test and train folds
        default is ``X`` being split to identical ``loc`` indices as ``y``
        if passed, must have same number of splits as ``cv``

    backend : string, optional (default=None)
        Parallelization backend to use for runs.
        Runs parallel evaluate if specified and ``strategy="refit"``.

        - "None": executes loop sequentially, simple list comprehension
        - "loky", "multiprocessing" and "threading": uses ``joblib.Parallel`` loops
        - "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``
        - "dask": uses ``dask``, requires ``dask`` package in environment
        - "dask_lazy": same as "dask",
          but changes the return to (lazy) ``dask.dataframe.DataFrame``.
        - "ray": uses ``ray``, requires ``ray`` package in environment

        Recommendation: Use "dask" or "loky" for parallel evaluate.
        "threading" is unlikely to see speed ups due to the GIL and the serialization
        backend (``cloudpickle``) for "dask" and "loky" is generally more robust
        than the standard ``pickle`` library used in "multiprocessing".

    backend_params : dict, optional
        additional parameters passed to the backend as config.
        Directly passed to ``utils.parallel.parallelize``.
        Valid keys depend on the value of ``backend``:

        - "None": no additional parameters, ``backend_params`` is ignored
        - "loky", "multiprocessing" and "threading": default ``joblib`` backends
          any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
          with the exception of ``backend`` which is directly controlled by ``backend``.
          If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
          will default to ``joblib`` defaults.
        - "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``.
          any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
          ``backend`` must be passed as a key of ``backend_params`` in this case.
          If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
          will default to ``joblib`` defaults.
        - "dask": any valid keys for ``dask.compute`` can be passed,
          e.g., ``scheduler``

        - "ray": The following keys can be passed:

            - "ray_remote_args": dictionary of valid keys for ``ray.init``
            - "shutdown_ray": bool, default=True; False prevents ``ray`` from shutting
                down after parallelization.
            - "logger_name": str, default="ray"; name of the logger to use.
            - "mute_warnings": bool, default=False; if True, suppresses warnings

    Examples
    --------
    Tuning an sktime forecaster via grid search

    1. defining the tuned estimator:
    >>> from sktime.forecasting.naive import NaiveForecaster
    >>> from sktime.split import ExpandingWindowSplitter
    >>> from hyperactive.integrations.sktime import ForecastingOptCV
    >>> from hyperactive.opt import GridSearchSk as GridSearch
    >>>
    >>> param_grid = {"strategy": ["mean", "last", "drift"]}
    >>> tuned_naive = ForecastingOptCV(
    ...     NaiveForecaster(),
    ...     GridSearch(param_grid),
    ...     cv=ExpandingWindowSplitter(
    ...         initial_window=12, step_length=3, fh=range(1, 13)
    ...     ),
    ... )

    2. fitting the tuned estimator:
    >>> from sktime.datasets import load_airline
    >>> from sktime.split import temporal_train_test_split
    >>> y = load_airline()
    >>> y_train, y_test = temporal_train_test_split(y, test_size=12)
    >>>
    >>> tuned_naive.fit(y_train, fh=range(1, 13))
    ForecastingOptCV(...)
    >>> y_pred = tuned_naive.predict()

    3. obtaining best parameters and best estimator
    >>> best_params = tuned_naive.best_params_
    >>> best_estimator = tuned_naive.best_forecaster_
    """

    _tags = {
        "authors": "fkiraly",
        "maintainers": "fkiraly",
        "python_dependencies": "sktime",
    }

    # attribute for _DelegatedForecaster, which then delegates
    #     all non-overridden methods are same as of getattr(self, _delegate_name)
    #     see further details in _DelegatedForecaster docstring
    _delegate_name = "best_forecaster_"

    def __init__(
        self,
        forecaster,
        optimizer,
        cv,
        strategy="refit",
        update_behaviour="full_refit",
        scoring=None,
        refit=True,
        error_score=np.nan,
        cv_X=None,
        backend=None,
        backend_params=None,
    ):
        # Parameters are stored unmodified under their own names;
        # no validation or copying happens at construction time.
        self.forecaster = forecaster
        self.optimizer = optimizer
        self.cv = cv
        self.strategy = strategy
        self.update_behaviour = update_behaviour
        self.scoring = scoring
        self.refit = refit
        self.error_score = error_score
        self.cv_X = cv_X
        self.backend = backend
        self.backend_params = backend_params
        super().__init__()

    def _fit(self, y, X, fh):
        """Tune the forecaster's hyperparameters and optionally refit it.

        Builds a ``SktimeForecastingExperiment`` from a clone of ``self.forecaster``
        and the tuning settings of ``self``, lets a clone of ``self.optimizer``
        solve it, and stores the result in ``best_params_`` and
        ``best_forecaster_``.

        Parameters
        ----------
        y : pd.Series
            Target time series to which to fit the forecaster.
        X : pd.DataFrame or None
            Exogenous time series. Passed on to the tuning experiment and,
            if ``self.refit`` is True, to the final fit of the best forecaster.
        fh : int, list or np.array
            The forecasting horizon with the steps ahead to predict.
            Only used in the final fit of the best forecaster, i.e.,
            when ``self.refit`` is True.

        Returns
        -------
        self : returns an instance of self.
        """
        # sktime is a soft dependency, hence the function-scope import
        from sktime.utils.validation.forecasting import check_scoring

        # work on clones so that the objects passed by the user stay untouched
        forecaster = self.forecaster.clone()

        scoring = check_scoring(self.scoring, obj=self)

        experiment = SktimeForecastingExperiment(
            forecaster=forecaster,
            scoring=scoring,
            cv=self.cv,
            X=X,
            y=y,
            strategy=self.strategy,
            error_score=self.error_score,
            cv_X=self.cv_X,
            backend=self.backend,
            backend_params=self.backend_params,
        )

        optimizer = self.optimizer.clone()
        optimizer.set_params(experiment=experiment)
        best_params = optimizer.solve()

        self.best_params_ = best_params
        self.best_forecaster_ = forecaster.set_params(**best_params)

        # Refit model with best parameters.
        if self.refit:
            self.best_forecaster_.fit(y=y, X=X, fh=fh)

        return self

    def _predict(self, fh, X):
        """Forecast time series at future horizon.

        private _predict containing the core logic, called from predict

        State required:
            Requires state to be "fitted".

        Accesses in self:
            Fitted model attributes ending in "_"
            self.cutoff

        Parameters
        ----------
        fh : guaranteed to be ForecastingHorizon or None, optional (default=None)
            The forecasting horizon with the steps ahead to predict.
            If not passed in _fit, guaranteed to be passed here
        X : pd.DataFrame, optional (default=None)
            Exogenous time series

        Returns
        -------
        y_pred : pd.Series
            Point predictions

        Raises
        ------
        RuntimeError
            If ``self.refit`` is False, since ``_fit`` then never fits
            ``best_forecaster_``.
        """
        if not self.refit:
            raise RuntimeError(
                f"In {self.__class__.__name__}, refit must be True to make predictions,"
                f" but found refit=False. If refit=False, {self.__class__.__name__} can"
                " be used only to tune hyper-parameters, as a parameter estimator."
            )
        return super()._predict(fh=fh, X=X)

    def _update(self, y, X=None, update_params=True):
        """Update time series to incremental training data.

        Dispatches on ``self.update_behaviour``:

        * "full_refit": calls the parent class ``_update`` with the arguments as passed
        * "inner_only": calls ``best_forecaster_.update`` with the arguments as passed
        * "no_update": calls ``best_forecaster_.update`` with ``update_params=False``,
          irrespective of the ``update_params`` argument

        Parameters
        ----------
        y : guaranteed to be of a type in self.get_tag("y_inner_mtype")
            Time series with which to update the forecaster.
            if self.get_tag("scitype:y")=="univariate":
                guaranteed to have a single column/variable
            if self.get_tag("scitype:y")=="multivariate":
                guaranteed to have 2 or more columns
            if self.get_tag("scitype:y")=="both": no restrictions apply
        X : optional (default=None)
            guaranteed to be of a type in self.get_tag("X_inner_mtype")
            Exogeneous time series for the forecast
        update_params : bool, optional (default=True)
            whether model parameters should be updated

        Returns
        -------
        self : reference to self

        Raises
        ------
        ValueError
            If ``self.update_behaviour`` is not one of
            "full_refit", "inner_only", "no_update".
        """
        update_behaviour = self.update_behaviour

        if update_behaviour == "full_refit":
            super()._update(y=y, X=X, update_params=update_params)
        elif update_behaviour == "inner_only":
            self.best_forecaster_.update(y=y, X=X, update_params=update_params)
        elif update_behaviour == "no_update":
            self.best_forecaster_.update(y=y, X=X, update_params=False)
        else:
            raise ValueError(
                'update_behaviour must be one of "full_refit", "inner_only",'
                f' or "no_update", but found {update_behaviour}'
            )
        return self

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return ``"default"`` set.

        Returns
        -------
        params : list of dict
            Three parameter sets, one each for grid search, random search and
            hill climbing. Between them they pass ``scoring`` as a metric object,
            a callable and a string, and cover all three ``update_behaviour``
            values (the first set relies on the default).
        """
        # sktime is a soft dependency, hence the function-scope imports
        from sktime.forecasting.naive import NaiveForecaster
        from sktime.forecasting.trend import PolynomialTrendForecaster
        from sktime.performance_metrics.forecasting import (
            MeanAbsolutePercentageError,
            mean_absolute_percentage_error,
        )
        from sktime.split import SingleWindowSplitter

        from hyperactive.opt.gfo import HillClimbing
        from hyperactive.opt.gridsearch import GridSearchSk
        from hyperactive.opt.random_search import RandomSearchSk

        params_gridsearch = {
            "forecaster": NaiveForecaster(strategy="mean"),
            "cv": SingleWindowSplitter(fh=1),
            "optimizer": GridSearchSk(param_grid={"window_length": [2, 5]}),
            "scoring": MeanAbsolutePercentageError(symmetric=True),
        }
        params_randomsearch = {
            "forecaster": PolynomialTrendForecaster(),
            "cv": SingleWindowSplitter(fh=1),
            "optimizer": RandomSearchSk(param_distributions={"degree": [1, 2]}),
            "scoring": mean_absolute_percentage_error,
            "update_behaviour": "inner_only",
        }
        params_hillclimb = {
            "forecaster": NaiveForecaster(strategy="mean"),
            "cv": SingleWindowSplitter(fh=1),
            "optimizer": HillClimbing(
                search_space={"window_length": [2, 5]},
                n_iter=10,
                n_neighbours=5,
            ),
            "scoring": "MeanAbsolutePercentageError(symmetric=True)",
            "update_behaviour": "no_update",
        }
        return [params_gridsearch, params_randomsearch, params_hillclimb]
367