forked from t-bltg/mpi4py-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path08-matrix-matrix-product.py
executable file
·76 lines (53 loc) · 1.97 KB
/
08-matrix-matrix-product.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#!/usr/bin/env python
import numpy as np
from mpi4py import MPI
from time import time
# ---------------------------------------------------------------------
# Local tile dimensions used for the per-process A and B matrices below.
my_N = my_M = 6000
# ---------------------------------------------------------------------
# Indices into the neighbour and request lists (one slot per direction).
NORTH, SOUTH, EAST, WEST = range(4)
comm = MPI.COMM_WORLD

# Largest mpi_rows x mpi_cols grid (mpi_rows <= mpi_cols) that fits into the
# available processes.  When comm.size is not such a product (e.g. 5 or 7),
# grid_size < comm.size and the leftover ranks are not part of the grid.
mpi_rows = int(np.floor(np.sqrt(comm.size)))
mpi_cols = comm.size // mpi_rows
grid_size = mpi_rows * mpi_cols
if comm.rank == 0:
    print( 'Creating a {:d} x {:d} processor grid...'.format(mpi_rows, mpi_cols) )
    if grid_size < comm.size:
        print('{:d} process(es) do not fit into the grid and stay idle'.format(comm.size - grid_size))

# Periodic 2-D Cartesian topology.  Ranks that do not fit into the grid get
# MPI.COMM_NULL back and must not use ccomm; they only take part in the
# COMM_WORLD barriers below so that no collective call deadlocks.
ccomm = comm.Create_cart( (mpi_rows, mpi_cols), periods=(True, True), reorder=True )
in_grid = ccomm != MPI.COMM_NULL

if in_grid:
    # Shift(direction, disp) returns the (source, dest) pair of ranks.
    neigh = [0, 0, 0, 0]
    neigh[NORTH], neigh[SOUTH] = ccomm.Shift(0, 1)
    neigh[EAST], neigh[WEST] = ccomm.Shift(1, 1)

    # Local tiles.  Both are shaped (my_N, my_M), so np.dot(tile_A, tile_B)
    # below requires my_M == my_N.
    my_A = np.random.normal(size=(my_N, my_M)).astype(np.float32)
    my_B = np.random.normal(size=(my_N, my_M)).astype(np.float32)
    my_C = np.zeros_like(my_A)

    # Double buffering: tile_X is the operand of the current step while
    # tile_X_ receives the neighbour's tile for the next step.
    tile_A, tile_B = my_A, my_B
    tile_A_, tile_B_ = np.empty_like(my_A), np.empty_like(my_B)
    req = [None, None, None, None]
    # Separate tags keep the A and B streams apart even when the same rank is
    # the source of both (e.g. every neighbour is self on a 1 x 1 grid),
    # instead of relying on MPI message-ordering rules.
    TAG_A, TAG_B = 0, 1

    t0 = time()
    for _ in range(mpi_rows):
        req[EAST] = ccomm.Isend(tile_A, neigh[EAST], tag=TAG_A)
        req[WEST] = ccomm.Irecv(tile_A_, neigh[WEST], tag=TAG_A)
        req[SOUTH] = ccomm.Isend(tile_B, neigh[SOUTH], tag=TAG_B)
        req[NORTH] = ccomm.Irecv(tile_B_, neigh[NORTH], tag=TAG_B)
        # The local product runs while the tile exchange is in flight; it only
        # reads the send buffers, never the receive buffers.
        my_C += np.dot(tile_A, tile_B)
        MPI.Request.Waitall(req)
        # The freshly received tiles are the operands of the next step; the
        # old tiles (whose sends completed in Waitall) become receive buffers.
        tile_A, tile_A_ = tile_A_, tile_A
        tile_B, tile_B_ = tile_B_, tile_B

comm.barrier()

if in_grid:
    t_total = time() - t0
    t0 = time()
    np.dot(tile_A, tile_B)
    t_serial = time() - t0
    # Report from grid rank 0: with reorder=True, COMM_WORLD rank 0 is not
    # guaranteed to be a member of the grid.
    if ccomm.rank == 0:
        print("===============================")
        print('computed (serial) {:d} x {:d} x {:d} in {:6.2f} seconds'.format(my_M, my_M, my_N, t_serial))
        # Only grid_size processes share the work, not comm.size.
        print('... expecting parallel computation to take {:6.2f} seconds'.format(mpi_rows * mpi_rows * mpi_cols * t_serial / grid_size))
        print('computed (parallel) {:d} x {:d} x {:d} in {:6.2f} seconds'.format(mpi_rows * my_M, mpi_rows * my_M, mpi_cols * my_N, t_total))
comm.barrier()