Passed — Push to master (65e4f4...bd0b66) by Simon, created 01:38

SktimeForecastingExperiment.get_test_params()   A

Complexity
    Conditions: 1

Size
    Total Lines: 56
    Code Lines: 19

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric    Value
eloc      19
dl        0
loc       56
rs        9.45
c         0
b         0
f         0
cc        1
nop       2

How to fix: Long Method

Small methods make your code easier to understand, especially when combined with a good name. Moreover, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for naming the new method.

Commonly applied refactorings include:

- Extract Method
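The comment-to-method-name heuristic described above can be sketched as follows. This is an illustrative toy example, not code from the report's file; the function names are made up for the sketch:

```python
# Before: one method whose sections need comments to be understood.
def summarize(values):
    # compute the mean of the values
    mean = sum(values) / len(values)
    # compute the spread around the mean
    spread = max(values) - min(values)
    return mean, spread


# After Extract Method: each commented section becomes a small named function,
# and the old comment supplies the new function's name and docstring.
def mean_of(values):
    """Compute the mean of the values."""
    return sum(values) / len(values)


def spread_of(values):
    """Compute the spread (range) of the values."""
    return max(values) - min(values)


def summarize_refactored(values):
    return mean_of(values), spread_of(values)


print(summarize_refactored([1.0, 2.0, 3.0]))  # prints (2.0, 2.0)
```

The refactored version behaves identically; the gain is that each piece now carries its explanation in its name rather than in a comment.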

"""Experiment adapter for sktime backtesting experiments."""
# copyright: hyperactive developers, MIT License (see LICENSE file)

import numpy as np

from hyperactive.base import BaseExperiment


class SktimeForecastingExperiment(BaseExperiment):
    """Experiment adapter for sktime backtesting experiments.

    This class is used to perform backtesting experiments using a given
    sktime forecaster. It allows for hyperparameter tuning and evaluation of
    the model's performance.

    The score returned is the summary backtesting score from applying
    ``sktime`` ``evaluate`` to ``forecaster``, with the parameters given
    in the ``params`` argument of ``score``.

    The backtesting performed is specified by the ``cv`` parameter,
    and the scoring metric is specified by the ``scoring`` parameter.
    The ``X`` and ``y`` parameters are the input data and target values,
    which are used in fit/predict cross-validation.

    Parameters
    ----------
    forecaster : sktime BaseForecaster descendant (concrete forecaster)
        sktime forecaster to benchmark

    cv : sktime BaseSplitter descendant
        determines split of ``y`` and possibly ``X`` into test and train folds
        ``y`` is always split according to ``cv``, see above
        if ``cv_X`` is not passed, ``X`` splits are subset to ``loc`` equal to ``y``
        if ``cv_X`` is passed, ``X`` is split according to ``cv_X``

    y : sktime time series container
        Target (endogenous) time series used in the evaluation experiment

    X : sktime time series container, of same mtype as y
        Exogenous time series used in the evaluation experiment

    strategy : {"refit", "update", "no-update_params"}, optional, default="refit"
        defines the ingestion mode when the forecaster sees new data when window expands
        "refit" = forecaster is refitted to each training window
        "update" = forecaster is updated with training window data, in sequence provided
        "no-update_params" = fit to first training window, re-used without fit or update

    scoring : subclass of sktime.performance_metrics.BaseMetric, default=None
        Used to get a score function that takes ``y_pred`` and ``y_test``
        arguments and accepts ``y_train`` as a keyword argument.
        If None, uses ``scoring = MeanAbsolutePercentageError(symmetric=True)``.

    error_score : "raise" or numeric, default=np.nan
        Value to assign to the score if an exception occurs in estimator fitting. If set
        to "raise", the exception is raised. If a numeric value is given,
        FitFailedWarning is raised.

    cv_X : sktime BaseSplitter descendant, optional
        determines split of ``X`` into test and train folds
        default is ``X`` being split to identical ``loc`` indices as ``y``
        if passed, must have same number of splits as ``cv``

    backend : str, optional, default="None"
        Parallelization backend to use for runs.
        Runs parallel evaluate if specified and ``strategy="refit"``.

        - "None": executes loop sequentially, simple list comprehension
        - "loky", "multiprocessing" and "threading": uses ``joblib.Parallel`` loops
        - "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``
        - "dask": uses ``dask``, requires ``dask`` package in environment
        - "dask_lazy": same as "dask",
          but changes the return to (lazy) ``dask.dataframe.DataFrame``.
        - "ray": uses ``ray``, requires ``ray`` package in environment

        Recommendation: Use "dask" or "loky" for parallel evaluate.
        "threading" is unlikely to see speed ups due to the GIL, and the
        serialization backend (``cloudpickle``) for "dask" and "loky" is generally
        more robust than the standard ``pickle`` library used in "multiprocessing".

    backend_params : dict, optional
        additional parameters passed to the backend as config.
        Directly passed to ``utils.parallel.parallelize``.
        Valid keys depend on the value of ``backend``:

        - "None": no additional parameters, ``backend_params`` is ignored
        - "loky", "multiprocessing" and "threading": default ``joblib`` backends
          any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
          with the exception of ``backend`` which is directly controlled by ``backend``.
          If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
          will default to ``joblib`` defaults.
        - "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``.
          any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
          ``backend`` must be passed as a key of ``backend_params`` in this case.
          If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
          will default to ``joblib`` defaults.
        - "dask": any valid keys for ``dask.compute`` can be passed,
          e.g., ``scheduler``
        - "ray": The following keys can be passed:

            - "ray_remote_args": dictionary of valid keys for ``ray.init``
            - "shutdown_ray": bool, default=True; False prevents ``ray`` from
              shutting down after parallelization.
            - "logger_name": str, default="ray"; name of the logger to use.
            - "mute_warnings": bool, default=False; if True, suppresses warnings

    Examples
    --------
    >>> from hyperactive.experiment.integrations import SktimeForecastingExperiment
    >>> from sktime.datasets import load_airline
    >>> from sktime.forecasting.naive import NaiveForecaster
    >>> from sktime.performance_metrics.forecasting import MeanAbsolutePercentageError
    >>> from sktime.split import ExpandingWindowSplitter
    >>>
    >>> y = load_airline()
    >>>
    >>> sktime_exp = SktimeForecastingExperiment(
    ...     forecaster=NaiveForecaster(strategy="last"),
    ...     scoring=MeanAbsolutePercentageError(),
    ...     cv=ExpandingWindowSplitter(initial_window=36, step_length=12, fh=12),
    ...     y=y,
    ... )
    >>> params = {"strategy": "mean"}
    >>> score, add_info = sktime_exp.score(params)

    For default choices of ``scoring``:

    >>> sktime_exp = SktimeForecastingExperiment(
    ...     forecaster=NaiveForecaster(strategy="last"),
    ...     cv=ExpandingWindowSplitter(initial_window=36, step_length=12, fh=12),
    ...     y=y,
    ... )
    >>> params = {"strategy": "mean"}
    >>> score, add_info = sktime_exp.score(params)

    Quick call without metadata return or dictionary:

    >>> score = sktime_exp(strategy="mean")
    """

    _tags = {
        "authors": "fkiraly",
        "maintainers": "fkiraly",
        "python_dependencies": "sktime",  # python dependencies
    }

    def __init__(
        self,
        forecaster,
        cv,
        y,
        X=None,
        strategy="refit",
        scoring=None,
        error_score=np.nan,
        cv_X=None,
        backend=None,
        backend_params=None,
    ):
        self.forecaster = forecaster
        self.X = X
        self.y = y
        self.strategy = strategy
        self.scoring = scoring
        self.cv = cv
        self.error_score = error_score
        self.cv_X = cv_X
        self.backend = backend
        self.backend_params = backend_params

        super().__init__()

        if scoring is None:
            from sktime.performance_metrics.forecasting import (
                MeanAbsolutePercentageError,
            )

            self._scoring = MeanAbsolutePercentageError(symmetric=True)
        else:
            self._scoring = scoring

        if scoring is None or scoring.get_tag("lower_is_better", False):
            higher_or_lower_better = "lower"
        else:
            higher_or_lower_better = "higher"
        self.set_tags(**{"property:higher_or_lower_is_better": higher_or_lower_better})

    def _paramnames(self):
        """Return the parameter names of the search.

        Returns
        -------
        list of str
            The parameter names of the search parameters.
        """
        return list(self.forecaster.get_params().keys())

    def _evaluate(self, params):
        """Evaluate the parameters.

        Parameters
        ----------
        params : dict with string keys
            Parameters to evaluate.

        Returns
        -------
        float
            The value of the parameters as per evaluation.
        dict
            Additional metadata about the search.
        """
        from sktime.forecasting.model_evaluation import evaluate

        results = evaluate(
            self.forecaster,
            cv=self.cv,
            y=self.y,
            X=self.X,
            strategy=self.strategy,
            scoring=self._scoring,
            error_score=self.error_score,
            cv_X=self.cv_X,
            backend=self.backend,
            backend_params=self.backend_params,
        )

        result_name = f"test_{self._scoring.name}"

        res_float = results[result_name].mean()

        return res_float, {"results": results}

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the skbase object.

        ``get_test_params`` is a unified interface point to store
        parameter settings for testing purposes. This function is also
        used in ``create_test_instance`` and ``create_test_instances_and_names``
        to construct test instances.

        ``get_test_params`` should return a single ``dict``, or a ``list`` of ``dict``.

        Each ``dict`` is a parameter configuration for testing,
        and can be used to construct an "interesting" test instance.
        A call to ``cls(**params)`` should
        be valid for all dictionaries ``params`` in the return of ``get_test_params``.

        The ``get_test_params`` need not return fixed lists of dictionaries,
        it can also return dynamic or stochastic parameter settings.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the set of test parameters to return, for use in tests. If no
            special parameters are defined for a value, will return the
            ``"default"`` set.

        Returns
        -------
        params : dict or list of dict, default = {}
            Parameters to create testing instances of the class.
            Each dict contains parameters to construct an "interesting" test
            instance, i.e., ``MyClass(**params)`` or ``MyClass(**params[i])``
            creates a valid test instance.
            ``create_test_instance`` uses the first (or only) dictionary
            in ``params``.
        """
        from sktime.datasets import load_airline, load_longley
        from sktime.forecasting.naive import NaiveForecaster
        from sktime.performance_metrics.forecasting import MeanAbsolutePercentageError
        from sktime.split import ExpandingWindowSplitter

        y = load_airline()
        params0 = {
            "forecaster": NaiveForecaster(strategy="last"),
            "cv": ExpandingWindowSplitter(initial_window=36, step_length=12, fh=12),
            "y": y,
        }

        y, X = load_longley()
        params1 = {
            "forecaster": NaiveForecaster(strategy="last"),
            "cv": ExpandingWindowSplitter(initial_window=3, step_length=3, fh=1),
            "y": y,
            "X": X,
            "scoring": MeanAbsolutePercentageError(symmetric=False),
        }

        return [params0, params1]

    @classmethod
    def _get_score_params(cls):
        """Return settings for testing score/evaluate functions. Used in tests only.

        Returns a list, the i-th element should be valid arguments for
        ``self.evaluate`` and ``self.score``, of an instance constructed with
        ``self.get_test_params()[i]``.

        Returns
        -------
        list of dict
            The parameters to be used for scoring.
        """
        val0 = {"strategy": "mean"}
        val1 = {"strategy": "last"}
        return [val0, val1]
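One way to address the Long Method finding for ``get_test_params`` would be to extract each parameter configuration into its own named helper. The sketch below is a structural stand-in only: it replaces the real sktime objects with plain strings so it is self-contained, and the helper names (``_airline_params``, ``_longley_params``) are hypothetical, not part of the actual code:

```python
class SktimeForecastingExperimentSketch:
    """Structural stand-in for the real class; sktime objects replaced by strings."""

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        # After extraction, the method body shrinks to naming the
        # configurations it returns.
        return [cls._airline_params(), cls._longley_params()]

    @classmethod
    def _airline_params(cls):
        # Hypothetical helper: univariate configuration (airline dataset).
        return {
            "forecaster": "NaiveForecaster(strategy='last')",
            "cv": "ExpandingWindowSplitter(initial_window=36, step_length=12, fh=12)",
            "y": "load_airline()",
        }

    @classmethod
    def _longley_params(cls):
        # Hypothetical helper: configuration with exogenous data and
        # an explicit scoring metric (longley dataset).
        return {
            "forecaster": "NaiveForecaster(strategy='last')",
            "cv": "ExpandingWindowSplitter(initial_window=3, step_length=3, fh=1)",
            "y": "load_longley()[0]",
            "X": "load_longley()[1]",
            "scoring": "MeanAbsolutePercentageError(symmetric=False)",
        }


params = SktimeForecastingExperimentSketch.get_test_params()
print(len(params))  # prints 2
```

The return value is unchanged, but each configuration now has a name documenting what makes it "interesting", and each helper is short enough to read at a glance.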