Total Complexity | 1 |
Total Lines | 30 |
Duplicated Lines | 0 % |
Changes | 0 |
1 | #!/usr/bin/env python |
||
2 | |||
3 | from map_parallel import starmap_parallel |
||
4 | |||
5 | from mpi4py import MPI |
||
6 | |||
# World communicator shared by every process launched for this test,
# and this process's index within it (rank 0 performs the final check).
comm = MPI.COMM_WORLD
rank = comm.rank
9 | |||
# Number of entries in each job's argument row (f below takes three).
n_args = 3
# Job count: a prime > 8, presumably chosen so the jobs cannot be split
# evenly across a typical number of MPI ranks -- TODO confirm intent.
n_jobs = 17

# Row `row` holds the arguments for one job; entry `col` of that row is
# 2*row^2 + 3*col^2 + 5*row*col, for col in [0, n_args).
ARGS = [
    [2 * row * row + 3 * col * col + 5 * row * col for col in range(n_args)]
    for row in range(n_jobs)
]
||
15 | |||
16 | |||
def f(x, y, z):
    """Return x*x + y*y - z*z; the per-job workload used by this test."""
    # Evaluated in the same left-to-right order as the plain expression.
    partial_sum = x * x + y * y
    return partial_sum - z * z
||
19 | |||
20 | |||
if __name__ == "__main__":
    # Every rank executes this call (it is outside the rank check below).
    res = starmap_parallel(f, ARGS, mode='mpi_simple', return_results=True)

    if rank == 0:
        # Serial reference: apply f to each argument row, i.e. f(*row).
        # Only rank 0 compares results, so only rank 0 builds the reference.
        truth = [f(*job_args) for job_args in ARGS]
        if res != truth:
            # Put both lists in the exception itself rather than printing to
            # stdout, where output from several ranks may interleave.
            raise AssertionError(
                "starmap_parallel results differ from serial reference:\n"
                "  got:      {!r}\n"
                "  expected: {!r}".format(res, truth)
            )
||
30 |