RandomSearchSk._check_param_distributions()   C

Complexity
    Conditions: 10

Size
    Total Lines: 24
    Code Lines: 15

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 0

Metric    Value
eloc      15
dl        0
loc       24
rs        5.9999
c         0
b         0
f         0
cc        10
nop       2
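The analyzer behind these figures is not named in the report. Assuming they are radon-style metrics (an assumption, shown for illustration only), the cc value can be cross-checked with the radon package:

    from radon.complexity import cc_visit

    # Path is illustrative; point it at the module shown in the listing below.
    with open("hyperactive/opt/random_search.py") as f:
        source = f.read()

    for block in cc_visit(source):
        # Each block is a function or method, reported with its
        # cyclomatic complexity (decision points plus one).
        print(block.name, block.complexity)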

How to fix

Complexity

Complex methods like hyperactive.opt.random_search.RandomSearchSk._check_param_distributions() often do a lot of different things. To break such a method down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for groups of statements that operate on the same data or serve the same purpose.

Once you have determined the statements that belong together, you can apply the Extract Method refactoring. If the extracted component makes sense as a class of its own, Extract Class is also a candidate. A sketch of the Extract Method route is shown below.
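The following is a minimal sketch of how Extract Method could apply here, reusing the validation logic from the listing below; the helper name _check_one_distribution is hypothetical and not part of hyperactive:

    def _check_one_distribution(self, name, v):
        """Validate a single search-space entry (hypothetical helper)."""
        if self._is_distribution(v):
            return  # frozen scipy distribution: sampled via rvs, nothing to check
        if isinstance(v, np.ndarray) and v.ndim > 1:
            raise ValueError("Parameter array should be one-dimensional.")
        if isinstance(v, str) or not isinstance(v, (np.ndarray, Sequence)):
            raise ValueError(
                f"Parameter distribution for ({name}) must be a list, numpy "
                f"array, or scipy.stats ``rv_frozen``, but got ({type(v)})."
            )
        if len(v) == 0:
            raise ValueError(
                f"Parameter values for ({name}) need to be a non-empty sequence."
            )

    def _check_param_distributions(self, param_distributions):
        """Validate ``param_distributions``; now it only normalizes and loops."""
        if hasattr(param_distributions, "items"):
            param_distributions = [param_distributions]
        for p in param_distributions:
            for name, v in p.items():
                self._check_one_distribution(name, v)

Each of the two resulting methods carries only a few conditions, well below the 10 flagged above, and the observable behaviour is unchanged.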

"""Random search optimizer."""

# copyright: hyperactive developers, MIT License (see LICENSE file)

from collections.abc import Sequence

import numpy as np
from sklearn.model_selection import ParameterSampler

from hyperactive.base import BaseOptimizer
from hyperactive.opt._common import _score_params
from hyperactive.utils.parallel import parallelize


class RandomSearchSk(BaseOptimizer):
    """Random search optimizer leveraging sklearn's ``ParameterSampler``.

    Parameters
    ----------
    param_distributions : dict[str, list | scipy.stats.rv_frozen]
        Search space specification. Discrete lists are sampled uniformly;
        scipy distribution objects are sampled via their ``rvs`` method.

    n_iter : int, default=10
        Number of parameter sets to evaluate.

    random_state : int | np.random.RandomState | None, default=None
        Controls the pseudo-random generator for reproducibility.

    error_score : float, default=np.nan
        Score assigned when the experiment raises an exception.

    backend : {"None", "loky", "multiprocessing", "threading", "joblib", "dask", "ray"}, default="None"
        Parallelization backend to use in the search process.

        - "None": executes the loop sequentially, as a simple list comprehension
        - "loky", "multiprocessing" and "threading": uses ``joblib.Parallel`` loops
        - "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``
        - "dask": uses ``dask``, requires ``dask`` package in environment
        - "ray": uses ``ray``, requires ``ray`` package in environment

    backend_params : dict, optional
        Additional parameters passed to the backend as config.
        Directly passed to ``utils.parallel.parallelize``.
        Valid keys depend on the value of ``backend``:

        - "None": no additional parameters, ``backend_params`` is ignored
        - "loky", "multiprocessing" and "threading": default ``joblib`` backends;
          any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
          with the exception of ``backend``, which is directly controlled by ``backend``.
          If ``n_jobs`` is not passed, it will default to ``-1``; other parameters
          will default to ``joblib`` defaults.
        - "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``;
          any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``;
          ``backend`` must be passed as a key of ``backend_params`` in this case.
          If ``n_jobs`` is not passed, it will default to ``-1``; other parameters
          will default to ``joblib`` defaults.
        - "dask": any valid keys for ``dask.compute`` can be passed, e.g., ``scheduler``

        - "ray": the following keys can be passed:

            - "ray_remote_args": dictionary of valid keys for ``ray.init``
            - "shutdown_ray": bool, default=True; False prevents ``ray`` from
              shutting down after parallelization.
            - "logger_name": str, default="ray"; name of the logger to use.
            - "mute_warnings": bool, default=False; if True, suppresses warnings

    experiment : BaseExperiment, optional
        Callable returning a scalar score when invoked with keyword
        arguments matching a parameter set.

    Examples
    --------
    Random search with different backend configurations:

    >>> from hyperactive.opt import RandomSearchSk
    >>> from scipy.stats import uniform
    >>> param_distributions = {
    ...     "C": uniform(loc=0.1, scale=10),
    ...     "gamma": ["scale", "auto", 0.001, 0.01, 0.1, 1],
    ... }
    >>>
    >>> # Sequential execution
    >>> random_search = RandomSearchSk(
    ...     param_distributions=param_distributions,
    ...     n_iter=20,
    ...     backend="None",
    ... )
    >>>
    >>> # Parallel execution with threading backend
    >>> random_search_parallel = RandomSearchSk(
    ...     param_distributions=param_distributions,
    ...     n_iter=20,
    ...     backend="threading",
    ...     backend_params={"n_jobs": 2},
    ... )

    Attributes
    ----------
    best_params_ : dict[str, Any]
        Hyper-parameter configuration with the best (lowest) score.
    best_score_ : float
        Score achieved by ``best_params_``.
    best_index_ : int
        Index of ``best_params_`` in the sampled sequence.
    """

    def __init__(
        self,
        param_distributions=None,
        n_iter=10,
        random_state=None,
        error_score=np.nan,
        backend="None",
        backend_params=None,
        experiment=None,
    ):
        self.experiment = experiment
        self.param_distributions = param_distributions
        self.n_iter = n_iter
        self.random_state = random_state
        self.error_score = error_score
        self.backend = backend
        self.backend_params = backend_params

        super().__init__()

    @staticmethod
    def _is_distribution(obj) -> bool:
        """Return True if *obj* looks like a scipy frozen distribution."""
        return callable(getattr(obj, "rvs", None))

    def _check_param_distributions(self, param_distributions):
        """Validate ``param_distributions`` similar to sklearn ≤1.0.x."""
        if hasattr(param_distributions, "items"):
            param_distributions = [param_distributions]

        for p in param_distributions:
            for name, v in p.items():
                if self._is_distribution(v):
                    # Assume scipy frozen distribution: nothing to check
                    continue

                if isinstance(v, np.ndarray) and v.ndim > 1:
                    raise ValueError("Parameter array should be one-dimensional.")

                if isinstance(v, str) or not isinstance(v, (np.ndarray, Sequence)):
                    raise ValueError(
                        f"Parameter distribution for ({name}) must be a list, numpy "
                        f"array, or scipy.stats ``rv_frozen``, but got ({type(v)})."
                        " Single values need to be wrapped in a sequence."
                    )

                if len(v) == 0:
                    raise ValueError(
                        f"Parameter values for ({name}) need to be a "
                        "non-empty sequence."
                    )

    def _solve(
        self,
        experiment,
        param_distributions,
        n_iter,
        random_state,
        error_score,
        backend,
        backend_params,
    ):
        """Sample ``n_iter`` points and return the best parameter set."""
        self._check_param_distributions(param_distributions)

        sampler = ParameterSampler(
            param_distributions=param_distributions,
            n_iter=n_iter,
            random_state=random_state,
        )
        candidate_params = list(sampler)

        meta = {
            "experiment": experiment,
            "error_score": error_score,
        }

        scores = parallelize(
            fun=_score_params,
            iter=candidate_params,
            meta=meta,
            backend=backend,
            backend_params=backend_params,
        )

        best_index = int(np.argmin(scores))  # lower-is-better convention
        best_params = candidate_params[best_index]

        # public attributes for external consumers
        self.best_index_ = best_index
        self.best_score_ = float(scores[best_index])
        self.best_params_ = best_params

        return best_params

    @classmethod
    def get_test_params(cls, parameter_set: str = "default"):
        """Provide deterministic toy configurations for unit tests."""
        from hyperactive.experiment.integrations import SklearnCvExperiment
        from hyperactive.experiment.toy import Ackley

        # 1) ML example (Iris + SVC)
        sklearn_exp = SklearnCvExperiment.create_test_instance()
        param_dist_1 = {
            "C": [0.01, 0.1, 1, 10],
            "gamma": np.logspace(-4, 1, 6),
        }
        params_sklearn = {
            "experiment": sklearn_exp,
            "param_distributions": param_dist_1,
            "n_iter": 5,
            "random_state": 42,
        }

        # 2) continuous optimisation example (Ackley)
        ackley_exp = Ackley.create_test_instance()
        param_dist_2 = {
            "x0": np.linspace(-5, 5, 50),
            "x1": np.linspace(-5, 5, 50),
        }
        params_ackley = {
            "experiment": ackley_exp,
            "param_distributions": param_dist_2,
            "n_iter": 20,
            "random_state": 0,
        }

        params = [params_sklearn, params_ackley]

        from hyperactive.utils.parallel import _get_parallel_test_fixtures

        parallel_fixtures = _get_parallel_test_fixtures()

        for x in parallel_fixtures:
            new_ackley = params_ackley.copy()
            new_ackley.update(x)
            params.append(new_ackley)

        return params
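For reference, a minimal sketch exercising the flagged validator; the inputs are chosen to hit each of its branches, the constructor call mirrors the ``__init__`` signature above, and the import follows the docstring example (calling the private method directly is for illustration only):

    import numpy as np
    from scipy.stats import uniform

    from hyperactive.opt import RandomSearchSk

    opt = RandomSearchSk(param_distributions={"C": [0.1, 1.0]})

    # Accepted: a frozen scipy distribution and a non-empty 1-D sequence.
    opt._check_param_distributions({"C": uniform(0, 10), "gamma": [0.01, 0.1]})

    # Rejected: each input below triggers one of the ValueError branches.
    for bad in (
        {"C": np.eye(2)},  # 2-D numpy array
        {"C": 5},          # bare scalar, not wrapped in a sequence
        {"C": []},         # empty sequence
    ):
        try:
            opt._check_param_distributions(bad)
        except ValueError as exc:
            print(exc)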