hyperactive.utils.parallel._parallelize_joblib()   A
last analyzed

Complexity

Conditions 5

Size

Total Lines 34
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 14
dl 0
loc 34
rs 9.2333
c 0
b 0
f 0
cc 5
nop 5
1
# copied from sktime, BSD-3-Clause License (see LICENSE file)
2
# to be moved to scikit-base in the future
3
"""Common abstraction utilities for parallelization backends.
4
5
New parallelization or iteration backends can be added easily as follows:
6
7
* Add a new backend name to ``backend_dict``, syntax is
8
  backend_name: backend_type, where backend_type collects backend options,
9
  e.g., multiple options for a single parallelization backend.
10
* Add a new function to ``para_dict``, should have name
11
  ``_parallelize_<backend_name>`` and take the same arguments as
12
  ``_parallelize_none``. Ensure that ``backend`` and ``backend_params`` are arguments,
13
  even if there is only one backend option, or no additional parameters.
14
* Add the backend string in the docstring of ``parallelize``, and in any
  downstream functions that use ``parallelize`` and expose the backend
  parameter as an argument.
16
"""
17
18
19
def parallelize(fun, iter, meta=None, backend=None, backend_params=None):
    """Parallelize loop over iter via backend.

    Executes ``fun(x, meta=meta)`` in parallel for ``x`` in ``iter``,
    and returns the results as a list in the same order as ``iter``.

    Uses the iteration or parallelization backend specified by ``backend``.

    Parameters
    ----------
    fun : callable, must have exactly two arguments, second argument of name "meta"
        function to be executed in parallel

    iter : iterable
        iterable over which to parallelize, elements are passed to fun in order,
        to the first argument

    meta : dict, optional
        variables to be passed to fun, as the second argument, under the key ``meta``

    backend : str, optional
        backend to use for parallelization, one of

        - "None": executes loop sequentially, simple list comprehension
        - "loky", "multiprocessing" and "threading": uses ``joblib`` ``Parallel`` loops
        - "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``
        - "dask": uses ``dask``, requires ``dask`` package in environment
        - "dask_lazy": same as ``"dask"``, but returns delayed object instead of list
        - "ray": uses a ray remote to execute jobs in parallel

    backend_params : dict, optional
        additional parameters passed to the backend as config.
        Valid keys depend on the value of ``backend``:

        - "None": no additional parameters, ``backend_params`` is ignored
        - "loky", "multiprocessing" and "threading": default ``joblib`` backends
          any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
          with the exception of ``backend`` which is directly controlled by ``backend``.
          If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
          will default to ``joblib`` defaults.
        - "joblib": custom and 3rd party ``joblib`` backends, e.g., ``spark``.
          any valid keys for ``joblib.Parallel`` can be passed here, e.g., ``n_jobs``,
          ``backend`` must be passed as a key of ``backend_params`` in this case.
          If ``n_jobs`` is not passed, it will default to ``-1``, other parameters
          will default to ``joblib`` defaults.
        - "dask": any valid keys for ``dask.compute`` can be passed, e.g., ``scheduler``

        - "ray": The following keys can be passed:

            - "ray_remote_args": dictionary of valid keys for ``ray.init``
            - "shutdown_ray": bool, default=True; False prevents ``ray`` from shutting
                down after parallelization.
            - "logger_name": str, default="ray"; name of the logger to use.
            - "mute_warnings": bool, default=False; if True, suppresses warnings

    Returns
    -------
    list
        results of ``fun`` applied to each element of ``iter``, in order
        (for ``backend="dask_lazy"``, a ``dask`` delayed object instead)

    Raises
    ------
    ValueError
        if ``backend`` is not one of the supported backend strings
    """
    # normalize optional arguments; fresh dict per call avoids shared mutable state
    if meta is None:
        meta = {}
    if backend is None:
        backend = "None"
    if backend_params is None:
        backend_params = {}

    # fail early with an informative message instead of a bare KeyError below
    if backend not in backend_dict:
        raise ValueError(
            f"Error in parallelize: unknown backend {backend!r}, "
            f"backend must be one of {list(backend_dict)}"
        )

    # two-step dispatch: user-facing backend string -> backend type -> function
    backend_name = backend_dict[backend]
    para_fun = para_dict[backend_name]

    ret = para_fun(
        fun=fun, iter=iter, meta=meta, backend=backend, backend_params=backend_params
    )
    return ret
89
90
91
# maps each user-facing backend string to the internal backend type that
# handles it, i.e., the key of para_dict whose function is dispatched to
backend_dict = {
    name: kind
    for kind, names in {
        "none": ["None"],
        "joblib": ["loky", "multiprocessing", "threading", "joblib"],
        "dask": ["dask", "dask_lazy"],
        "ray": ["ray"],
    }.items()
    for name in names
}
# registry of parallelization entry points; populated below, next to each
# _parallelize_<name> definition
para_dict = {}
102
103
104
def _parallelize_none(fun, iter, meta, backend, backend_params):
105
    """Execute loop via simple sequential list comprehension."""
106
    ret = [fun(x, meta=meta) for x in iter]
107
    return ret
108
109
110
# register the sequential executor under its internal backend type name
para_dict["none"] = _parallelize_none
111
112
113
def _parallelize_joblib(fun, iter, meta, backend, backend_params):
    """Parallelize loop via joblib ``Parallel``/``delayed``.

    For the first-party backends ("loky", "multiprocessing", "threading"),
    ``backend`` determines the joblib backend, overriding any "backend" key
    in ``backend_params``.  For the two-layer ``"joblib"`` dispatch string,
    the concrete joblib backend must be supplied by the user via
    ``backend_params["backend"]``, otherwise a ``ValueError`` is raised.
    ``n_jobs`` defaults to ``-1`` unless set in ``backend_params``.
    """
    from joblib import Parallel, delayed

    # copy so the caller's backend_params dict is never mutated
    params = dict(backend_params)

    if backend == "joblib":
        # two-layer dispatch: the actual joblib backend, e.g., "spark",
        # must come from the user; raise if it was not provided
        if "backend" not in params:
            raise ValueError(
                '"joblib" was selected as first layer parallelization backend, '
                "but no backend string was "
                'passed in the backend parameters dict, e.g., "spark". '
                "Please specify a backend to joblib as a key-value pair "
                "in the backend_params arg or the backend:parallel:params config "
                'when using "joblib". '
                'For clarity, "joblib" should only be used for two-layer '
                "backend dispatch, where the first layer is joblib, "
                "and the second layer is a custom backend of joblib, e.g., spark. "
                "For first-party joblib backends, please use the backend string "
                'of sktime directly, e.g., by specifying "multiprocessing" or "loky".'
            )
    else:
        # first-party joblib backend: the backend argument always wins,
        # even over a "backend" key the user put into backend_params
        params["backend"] = backend

    params.setdefault("n_jobs", -1)

    return Parallel(**params)(delayed(fun)(x, meta=meta) for x in iter)
147
148
149
# register the joblib executor; serves "loky", "multiprocessing",
# "threading" and the two-layer "joblib" backend strings
para_dict["joblib"] = _parallelize_joblib
150
151
152
def _parallelize_dask(fun, iter, meta, backend, backend_params):
    """Parallelize loop via dask, eagerly or lazily depending on ``backend``.

    For ``backend="dask"``, the delayed computations are executed via
    ``dask.compute`` (with ``backend_params`` forwarded) and a tuple of
    results is returned; for any other value (i.e., ``"dask_lazy"``),
    the uncomputed delayed objects are returned instead.
    """
    from dask import compute, delayed

    delayed_jobs = [delayed(fun)(x, meta=meta) for x in iter]

    if backend != "dask":
        # "dask_lazy": hand back the delayed objects without computing
        return delayed_jobs

    return compute(*delayed_jobs, **backend_params)
161
162
163
# register the dask executor; serves both "dask" and "dask_lazy"
para_dict["dask"] = _parallelize_dask
164
165
166
def _parallelize_ray(fun, iter, meta, backend, backend_params):
    """Parallelize loop via ray.

    Submits one ray remote task per element of ``iter``, collects results
    as tasks complete, and returns them as a list in submission order.
    Initializes ray if it is not already running; optionally shuts it
    down afterwards (``shutdown_ray`` key, default True).
    """
    import logging
    import warnings

    import ray

    # copy so the caller's backend_params dict is not mutated below
    par_params = backend_params.copy()

    # read the possible additional keys
    # NOTE(review): parallelize's docstring documents logger_name default="ray",
    # but the code default here is None, i.e., the root logger — confirm intent
    logger = logging.getLogger(par_params.get("logger_name", None))
    mute_warnings = par_params.get("mute_warnings", False)
    shutdown_ray = par_params.get("shutdown_ray", True)

    if "ray_remote_args" not in par_params.keys():
        par_params["ray_remote_args"] = {}

    @ray.remote  # pragma: no cover
    def _ray_execute_function(
        fun, params: dict, meta: dict, mute_warnings: bool = False
    ):
        # runs on the ray worker; suppress warnings there if requested
        if mute_warnings:
            warnings.filterwarnings("ignore")  # silence sktime warnings
        assert ray.is_initialized()
        # NOTE(review): meta is passed positionally here, while other backends
        # call fun(x, meta=meta) — equivalent only because parallelize requires
        # fun's second argument to be named "meta"; confirm
        result = fun(params, meta)
        return result

    if not ray.is_initialized():
        logger.info("Starting Ray Parallel")
        context = ray.init(**par_params["ray_remote_args"])
        logger.info(
            f"Ray initialized. Open dashboard at http://{context.dashboard_url}"
        )

    # this is to keep the order of results while still using wait to optimize runtime
    refs = [
        _ray_execute_function.remote(fun, x, meta, mute_warnings=mute_warnings)
        for x in iter
    ]
    # object ref -> result, keyed in submission order so ordering can be restored
    res_dict = dict.fromkeys(refs)

    unfinished = refs
    while unfinished:
        # harvest results one at a time, as soon as any task completes
        finished, unfinished = ray.wait(unfinished, num_returns=1)
        res_dict[finished[0]] = ray.get(finished[0])

    if shutdown_ray:
        ray.shutdown()

    # reassemble the results in the original submission order
    res = [res_dict[ref] for ref in refs]
    return res
217
218
219
# register the ray executor under its internal backend type name
para_dict["ray"] = _parallelize_ray
220
221
222
# list of backends where we skip tests during CI
# entries are user-facing backend strings (as accepted by parallelize) that
# are meant to be excluded from the parallelization test fixtures
SKIP_FIXTURES = [
    "ray",  # unstable, sporadic crashes in CI, see bug 8149
]
226
227
228
def _get_parallel_test_fixtures(naming="estimator"):
    """Return fixtures for parallelization tests.

    Returns a list of parameter fixtures, where each fixture
    is a dict with keys "backend" and "backend_params".

    Parameters
    ----------
    naming : str, optional
        naming convention for the parameters, one of

        "estimator": for use in estimator constructors,
        ``backend`` and ``backend_params``
        "config": for use in ``set_config``,
        ``backend:parallel`` and ``backend:parallel:params``

    Returns
    -------
    fixtures : list of dict
        list of backend parameter fixtures
        keys depend on ``naming`` parameter, see above
        either ``backend`` and ``backend_params`` (``naming="estimator"``),
        or ``backend:parallel`` and ``backend:parallel:params`` (``naming="config"``)
        values are backend strings and backend parameter dicts
        only backends that are available in the environment are included
    """
    from skbase.utils.dependencies import _check_soft_dependencies

    # NOTE(review): naming is currently unused — keys are always emitted in
    # the "estimator" convention; confirm whether "config" renaming is needed
    fixtures = []

    # test no parallelization
    fixtures.append({"backend": "None", "backend_params": {}})

    # test joblib backends
    for backend in ["loky", "multiprocessing", "threading"]:
        fixtures.append({"backend": backend, "backend_params": {}})
        fixtures.append({"backend": backend, "backend_params": {"n_jobs": 2}})
        fixtures.append({"backend": backend, "backend_params": {"n_jobs": -1}})

    # test dask backends
    if _check_soft_dependencies("dask", severity="none"):
        fixtures.append({"backend": "dask", "backend_params": {}})
        fixtures.append({"backend": "dask", "backend_params": {"scheduler": "sync"}})

    # test ray backend
    # TODO: faster ray test
    # if _check_soft_dependencies("ray", severity="none"):
    #     import os
    #
    #     fixtures.append(
    #         {
    #             "backend": "ray",
    #             "backend_params": {
    #                 "mute_warnings": True,
    #                 "ray_remote_args": {"num_cpus": os.cpu_count() - 1},
    #             },
    #         }
    #     )

    # remove backends in SKIP_FIXTURES from fixtures
    # fix: this filter was previously trapped inside the triple-quoted TODO
    # string above, so it never executed; it is now live (no behavior change
    # today, since no SKIP_FIXTURES backend is among the active fixtures)
    fixtures = [x for x in fixtures if x["backend"] not in SKIP_FIXTURES]

    return fixtures
291