Commit 6f1b2a85 authored by Sebastian Ohlmann's avatar Sebastian Ohlmann

add iterator over global indices to python wrapper

parent 8ba73f1b
......@@ -327,3 +327,27 @@ class DistributedMatrix:
# this could be done more efficiently with a gather
self.processor_layout.comm.Allreduce(temporary, row, op=MPI.SUM)
return row
def global_indices(self):
"""Return iterator over global indices of matrix.
Use together with set_data_global_index and get_data_global_index.
"""
for local_row in range(self.na_rows):
for local_col in range(self.na_cols):
yield self.get_global_index(local_row, local_col)
def set_data_for_global_index(self, global_row, global_col, value):
"""Set value of matrix at global coordinates"""
if self.is_local_index(global_row, global_col):
local_row, local_col = self.get_local_index(global_row, global_col)
self.data[local_row, local_col] = value
def get_data_for_global_index(self, global_row, global_col):
"""Get value of matrix at global coordinates"""
if self.is_local_index(global_row, global_col):
local_row, local_col = self.get_local_index(global_row, global_col)
return self.data[local_row, local_col]
else:
raise ValueError('Index out of bounds: global row {:d}, '
'global col {:d}'.format(global_row, global_col))
......@@ -376,3 +376,29 @@ def test_accessing_matrix(na, nev, nblk):
assert(np.allclose(column, matrix[:, index]))
row = a.get_row(index)
assert(np.allclose(row, matrix[index, :]))
@pytest.mark.parametrize("na,nev,nblk", parameter_list)
def test_global_index_iterator(na, nev, nblk):
import numpy as np
from pyelpa import ProcessorLayout, DistributedMatrix
for dtype in [np.float64, np.complex128]:
a = DistributedMatrix.from_comm_world(na, nev, nblk, dtype=dtype)
for i, j in a.global_indices():
assert(a.is_local_index(i, j))
@pytest.mark.parametrize("na,nev,nblk", parameter_list)
def test_global_index_access(na, nev, nblk):
import numpy as np
from pyelpa import ProcessorLayout, DistributedMatrix
for dtype in [np.float64, np.complex128]:
a = DistributedMatrix.from_comm_world(na, nev, nblk, dtype=dtype)
for i, j in a.global_indices():
x = dtype(i*j)
a.set_data_for_global_index(i, j, x)
for i, j in a.global_indices():
x = a.get_data_for_global_index(i, j)
assert(np.isclose(x, i*j))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment