@@ 172-197 (lines=26) @@ | ||
169 | return [] |
|
170 | ||
171 | ||
def _starmap_parallel_mpi_simple(
    f: Callable,
    args: list,
    return_results: bool = True,
    **kwargs,
) -> list:
    """Evaluate ``f(*item)`` for every item of *args*, split round-robin over MPI ranks.

    Rank ``r`` of ``MPI.COMM_WORLD`` evaluates ``args[r::size]``. When
    ``return_results`` is True the per-rank results are collected on rank 0
    with the collective ``comm.gather``, so every rank must call this function
    (presumably with identical *args* — confirm with callers).

    Args:
        f: Callable invoked as ``f(*item)`` for each item of *args*.
        args: Sequence of argument tuples; must support ``len`` and slicing.
        return_results: If True, gather all results onto rank 0.
        **kwargs: Accepted and ignored here (presumably for signature
            compatibility with the other parallel backends — confirm).

    Returns:
        On rank 0 with ``return_results=True``: the list of results, in the
        same order as *args*. On every other rank, or when ``return_results``
        is False: an empty list.
    """
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    if args:
        local_res = list(starmap(f, args[rank::size]))
        n_items = len(args)
    else:
        # No work: still take part in the collective gather below instead of
        # failing on an unbound result variable.
        local_res = []
        n_items = 0

    if return_results:
        gathered = comm.gather(local_res, root=0)
        if rank == 0:
            # Rank r computed global items r, r + size, r + 2*size, ...;
            # undo that striding to restore the original order.
            return [gathered[i % size][i // size] for i in range(n_items)]
    return []
|
198 | ||
199 | ||
200 | _map_parallel_func: Dict[str, Callable] = { |
|
@@ 144-169 (lines=26) @@ | ||
141 | return [] |
|
142 | ||
143 | ||
def _map_parallel_mpi_simple(
    f: Callable,
    *args,
    return_results: bool = True,
    **kwargs,
) -> list:
    """Evaluate ``map(f, *args)`` with the work split round-robin over MPI ranks.

    Rank ``r`` of ``MPI.COMM_WORLD`` evaluates ``map(f, *(a[r::size] for a in
    args))``. When ``return_results`` is True the per-rank results are
    collected on rank 0 with the collective ``comm.gather``, so every rank
    must call this function (presumably with identical *args* — confirm with
    callers).

    Args:
        f: Callable taking one positional argument per iterable in *args*.
        *args: Sequences to map over in lockstep; each must support ``len``
            and slicing. As with ``map``, iteration stops at the shortest.
        return_results: If True, gather all results onto rank 0.
        **kwargs: Accepted and ignored here (presumably for signature
            compatibility with the other parallel backends — confirm).

    Returns:
        On rank 0 with ``return_results=True``: the list of results, in the
        same order ``map(f, *args)`` would produce them. On every other rank,
        or when ``return_results`` is False: an empty list.
    """
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    rank = comm.Get_rank()

    if args:
        local_args = [arg[rank::size] for arg in args]
        local_res = list(map(f, *local_args))
        # map() stops at the shortest iterable, so the global result length is
        # the minimum input length, not len(args[0]). Using len(args[0]) would
        # index past the end of a rank's results when args[0] is longer.
        n_items = min(len(arg) for arg in args)
    else:
        # No iterables: still take part in the collective gather below instead
        # of failing on an unbound result variable.
        local_res = []
        n_items = 0

    if return_results:
        gathered = comm.gather(local_res, root=0)
        if rank == 0:
            # Rank r computed global items r, r + size, r + 2*size, ...;
            # undo that striding to restore the original order.
            return [gathered[i % size][i // size] for i in range(n_items)]
    return []
|
170 | ||
171 | ||
172 | def _starmap_parallel_mpi_simple( |