Code Duplication    Length = 26-26 lines in 2 locations

src/map_parallel/__init__.py 2 locations

@@ 172-197 (lines=26) @@
169
    return []
170
171
172
def _starmap_parallel_mpi_simple(
    f: Callable,
    args: list,
    return_results: bool = True,
    **kwargs,
) -> list:
    """Evaluate ``f(*a)`` for every tuple ``a`` in ``args``, spread over MPI ranks.

    Work is split round-robin over ``MPI.COMM_WORLD``: rank ``r`` evaluates
    ``args[r::size]``. When ``return_results`` is true, the per-rank result
    lists are gathered on rank 0 and re-interleaved so the output order
    matches the order of ``args``.

    Args:
        f: Callable invoked as ``f(*a)`` for each element ``a`` of ``args``.
        args: Sequence of argument tuples; it must support slicing.
        return_results: If false, results are computed on each rank but
            neither gathered nor returned.
        **kwargs: Accepted but unused by this backend.

    Returns:
        On rank 0 with ``return_results`` true and non-empty ``args``: the
        list of results in input order. In every other case (non-root ranks,
        ``return_results`` false, or empty ``args``): an empty list.
    """
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    if not args:
        return []

    # Round-robin split: rank r handles global indices r, r + size, r + 2*size, ...
    local_results = list(starmap(f, args[rank::size]))

    if not return_results:
        return []

    # gather is an MPI collective: every rank calls it, only root receives the
    # list of per-rank results (non-root ranks receive None).
    gathered = comm.gather(local_results, root=0)
    if rank != 0:
        return []

    # Global item i was computed by rank (i % size) as its (i // size)-th result.
    return [gathered[i % size][i // size] for i in range(len(args))]
198
199
200
_map_parallel_func: Dict[str, Callable] = {
@@ 144-169 (lines=26) @@
141
            return []
142
143
144
def _map_parallel_mpi_simple(
    f: Callable,
    *args,
    return_results: bool = True,
    **kwargs,
) -> list:
    """Evaluate ``f`` element-wise over the iterables in ``args`` across MPI ranks.

    This is the ``map``-style counterpart of ``_starmap_parallel_mpi_simple``.
    The input iterables are zipped into argument tuples and dispatched through
    that function, so both entry points share one implementation of the
    round-robin split, the gather, and the re-ordering.

    Like builtin ``map(f, *args)``, evaluation stops at the shortest input
    iterable. Inputs may be any iterables; they do not need to support
    slicing.

    Args:
        f: Callable invoked as ``f(x0, x1, ...)`` with one item from each
            iterable in ``args``.
        *args: Iterables supplying ``f``'s positional arguments.
        return_results: If false, results are computed on each rank but
            neither gathered nor returned.
        **kwargs: Accepted but unused by this backend.

    Returns:
        On rank 0 with ``return_results`` true: the results in input order.
        Otherwise (non-root ranks, ``return_results`` false, or no/empty
        inputs): an empty list.
    """
    # zip() stops at the shortest iterable, matching map(f, *args). Sizing the
    # output by len(args[0]) instead would index past the gathered results on
    # rank 0 whenever the first iterable is longer than another one.
    arg_tuples = list(zip(*args))
    # kwargs are ignored by both MPI backends, so they are deliberately not
    # forwarded (this also avoids a clash if a caller passed ``args=``).
    return _starmap_parallel_mpi_simple(
        f, arg_tuples, return_results=return_results
    )
170
171
172
def _starmap_parallel_mpi_simple(