Test Failed
Push — master ( c6c275...59c999 )
by Kolen
created 04:51

map_parallel._map_parallel_mpi_simple()  (rating: B)

Complexity
  Conditions: 5

Size
  Total Lines: 26
  Code Lines: 22

Duplication
  Lines: 26
  Ratio: 100 %

Importance
  Changes: 0

Metric  Value
cc      5
eloc    22
nop     4
dl      26
loc     26
rs      8.8853
c       0
b       0
f       0
from __future__ import annotations

__version__ = '0.1.1'

from functools import partial
from itertools import starmap
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Optional, Callable, Dict
    from collections.abc import Iterable


def _starfunc(f: Callable, x):
    '''return f(*x)
    '''
    return f(*x)
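
# Illustration only (not part of the original module): _starfunc unpacks its
# second argument, so _starfunc(pow, (2, 3)) evaluates pow(2, 3) == 8.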


def _map_parallel_multiprocessing(
    f: Callable,
    *args,
    processes: Optional[int] = None,
    return_results: bool = True,
) -> list:
    from concurrent.futures import ProcessPoolExecutor

    with ProcessPoolExecutor(max_workers=processes) as process_pool_executor:
        res = process_pool_executor.map(f, *args)
        if return_results:
            return list(res)
        else:
            return []


def _starmap_parallel_multiprocessing(
    f: Callable,
    args: Iterable,
    processes: Optional[int] = None,
    return_results: bool = True,
) -> list:
    from concurrent.futures import ProcessPoolExecutor

    with ProcessPoolExecutor(max_workers=processes) as process_pool_executor:
        res = process_pool_executor.map(partial(_starfunc, f), args)
        if return_results:
            return list(res)
        else:
            return []


def _map_parallel_multithreading(
    f: Callable,
    *args,
    processes: Optional[int] = None,
    return_results: bool = True,
) -> list:
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=processes) as thread_pool_executor:
        res = thread_pool_executor.map(f, *args)
        if return_results:
            return list(res)
        else:
            return []


def _starmap_parallel_multithreading(
    f: Callable,
    args: Iterable,
    processes: Optional[int] = None,
    return_results: bool = True,
) -> list:
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=processes) as thread_pool_executor:
        res = thread_pool_executor.map(partial(_starfunc, f), args)
        if return_results:
            return list(res)
        else:
            return []


def _map_parallel_dask(
    f: Callable,
    *args,
    processes: Optional[int] = None,
    return_results: bool = True,
) -> list:
    from dask.distributed import Client
    from dask.distributed import LocalCluster

    cluster = LocalCluster(n_workers=processes, dashboard_address=None)
    client = Client(cluster)
    if return_results:
        return [future.result() for future in client.map(f, *args)]
    else:
        for future in client.map(f, *args):
            future.result()
        return []


def _starmap_parallel_dask(
    f: Callable,
    args: Iterable,
    processes: Optional[int] = None,
    return_results: bool = True,
) -> list:
    from dask.distributed import Client
    from dask.distributed import LocalCluster

    cluster = LocalCluster(n_workers=processes, dashboard_address=None)
    client = Client(cluster)
    if return_results:
        return [future.result() for future in client.map(partial(_starfunc, f), args)]
    else:
        for future in client.map(partial(_starfunc, f), args):
            future.result()
        return []


def _map_parallel_mpi(f: Callable, *args, return_results: bool = True, **kwargs) -> list:
    from mpi4py.futures import MPIPoolExecutor

    with MPIPoolExecutor() as mpi_pool_executor:
        res = mpi_pool_executor.map(f, *args)
        if return_results:
            return list(res)
        else:
            return []


def _starmap_parallel_mpi(f: Callable, args: Iterable, return_results: bool = True, **kwargs) -> list:
    from mpi4py.futures import MPIPoolExecutor

    with MPIPoolExecutor() as mpi_pool_executor:
        res = mpi_pool_executor.starmap(f, args)
        if return_results:
            return list(res)
        else:
            return []


View Code Duplication: this code seems to be duplicated in your project.
def _map_parallel_mpi_simple(
    f: Callable,
    *args,
    return_results: bool = True,
    **kwargs,
) -> list:
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    if args:
        # round-robin scheduling: rank r handles items r, r + size, r + 2 * size, ...
        local_args = [arg[rank::size] for arg in args]
        res = list(map(f, *local_args))

        if return_results:
            # gather every rank's partial results on rank 0 and restore the original order
            res = comm.gather(res, root=0)
            if rank == 0:
                all_res = []
                for i in range(len(args[0])):
                    local_rank = i % size
                    local_i = i // size
                    all_res.append(res[local_rank][local_i])
                return all_res
    return []
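
# Illustration only (not part of the module): with size == 2 ranks and a single
# argument of 5 items, rank 0 computes f over items 0, 2, 4 and rank 1 over
# items 1, 3. After comm.gather, rank 0 holds [[f(0), f(2), f(4)], [f(1), f(3)]],
# and the loop above rebuilds [f(0), f(1), f(2), f(3), f(4)] because item i
# lives at res[i % size][i // size].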


View Code Duplication: this code seems to be duplicated in your project.
def _starmap_parallel_mpi_simple(
    f: Callable,
    args: list,
    return_results: bool = True,
    **kwargs,
):
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    if args:
        # round-robin scheduling: rank r handles argument tuples r, r + size, r + 2 * size, ...
        local_args = args[rank::size]
        res = list(starmap(f, local_args))

        if return_results:
            # gather every rank's partial results on rank 0 and restore the original order
            res = comm.gather(res, root=0)
            if rank == 0:
                all_res = []
                for i in range(len(args)):
                    local_rank = i % size
                    local_i = i // size
                    all_res.append(res[local_rank][local_i])
                return all_res
    return []


Issue: the variable Callable does not seem to be defined when the `if TYPE_CHECKING:` guard is False. Are you sure this can never be the case?
_map_parallel_func: Dict[str, Callable] = {
    'multiprocessing': _map_parallel_multiprocessing,
    'multithreading': _map_parallel_multithreading,
    'dask': _map_parallel_dask,
    'mpi': _map_parallel_mpi,
    'mpi_simple': _map_parallel_mpi_simple,
}


_starmap_parallel_func: Dict[str, Callable] = {
    'multiprocessing': _starmap_parallel_multiprocessing,
    'multithreading': _starmap_parallel_multithreading,
    'dask': _starmap_parallel_dask,
    'mpi': _starmap_parallel_mpi,
    'mpi_simple': _starmap_parallel_mpi_simple,
}


def map_parallel(
    f: Callable,
    *args,
    processes: Optional[int] = None,
    mode: str = 'multiprocessing',
    return_results: bool = True,
) -> list:
    '''Equivalent to `map(f, *args)`, but in parallel.

    :param str mode: backend for parallelization
        - multiprocessing: using multiprocessing from the standard library
        - multithreading: using multithreading from the standard library
        - dask: using dask.distributed
        - mpi: using mpi4py.futures. May not work depending on your MPI vendor
        - mpi_simple: using mpi4py with simple scheduling that divides the work into equal chunks
        - serial: using map
    :param int processes: number of parallel processes
        (in the case of mpi, it is determined by mpiexec/mpirun args)
    :param bool return_results: (only affects mode == 'mpi_simple') if True, return results on rank 0.
    '''
    if processes is None or processes > 1:
        try:
            return _map_parallel_func[mode](f, *args, processes=processes, return_results=return_results)
        except KeyError:
            # unknown mode (e.g. 'serial') falls back to a serial map below
            pass
    return list(map(f, *args))
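
# Usage sketch (illustration only, not part of the module), assuming this file
# is importable as `map_parallel`:
#
#     from map_parallel import map_parallel
#
#     def square(x):
#         return x * x
#
#     map_parallel(square, range(8), processes=4, mode='multiprocessing')
#     # -> [0, 1, 4, 9, 16, 25, 36, 49]
#
# For the MPI backends the process count comes from the launcher instead,
# e.g. something like `mpiexec -n 4 python your_script.py` for 'mpi_simple'.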


def starmap_parallel(
    f: Callable,
    args: Iterable,
    processes: Optional[int] = None,
    mode: str = 'multiprocessing',
    return_results: bool = True,
) -> list:
    '''Equivalent to `starmap(f, args)`, but in parallel.

    :param str mode: backend for parallelization
        - multiprocessing: using multiprocessing from the standard library
        - multithreading: using multithreading from the standard library
        - dask: using dask.distributed
        - mpi: using mpi4py.futures. May not work depending on your MPI vendor
        - mpi_simple: using mpi4py with simple scheduling that divides the work into equal chunks
        - serial: using starmap
    :param int processes: number of parallel processes
        (in the case of mpi, it is determined by mpiexec/mpirun args)
    :param bool return_results: (only affects mode == 'mpi_simple') if True, return results on rank 0.
    '''
    if processes is None or processes > 1:
        try:
            return _starmap_parallel_func[mode](f, args, processes=processes, return_results=return_results)
        except KeyError:
            # unknown mode (e.g. 'serial') falls back to a serial starmap below
            pass
    return list(starmap(f, args))
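
# Usage sketch (illustration only, not part of the module), assuming the module
# is importable as `map_parallel`:
#
#     from map_parallel import starmap_parallel
#
#     def add(x, y):
#         return x + y
#
#     starmap_parallel(add, [(1, 2), (3, 4), (5, 6)], mode='multithreading')
#     # -> [3, 7, 11]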