Skip to content
Snippets Groups Projects

Mf imaging

Merged Simon Ding requested to merge mf_imaging into master
2 unresolved threads
Files
2
+ 31
7
@@ -39,6 +39,9 @@ class AllreduceSum(ift.Operator):
@@ -39,6 +39,9 @@ class AllreduceSum(ift.Operator):
class SliceSum(ift.Operator):
class SliceSum(ift.Operator):
 
"""
 
Sum Operator that slices along the first axis of the input array and computes the sum in parallel using MPI.
 
"""
def __init__(self, oplist, index_low, parallel_space, comm):
def __init__(self, oplist, index_low, parallel_space, comm):
# FIXME if oplist contains only linear operators instantiate
# FIXME if oplist contains only linear operators instantiate
# SliceSumLinear instead
# SliceSumLinear instead
@@ -93,6 +96,8 @@ class SliceSumLinear(ift.LinearOperator):
@@ -93,6 +96,8 @@ class SliceSumLinear(ift.LinearOperator):
self._oplist = oplist
self._oplist = oplist
self._comm = comm
self._comm = comm
self._lo = index_low
self._lo = index_low
 
local_nwork = [len(oplist)] if comm is None else comm.allgather(len(oplist))
 
self._nwork = sum(local_nwork)
def apply(self, x, mode):
def apply(self, x, mode):
self._check_input(x, mode)
self._check_input(x, mode)
@@ -106,7 +111,7 @@ class SliceSumLinear(ift.LinearOperator):
@@ -106,7 +111,7 @@ class SliceSumLinear(ift.LinearOperator):
)
)
else:
else:
arr = array_allgather(
arr = array_allgather(
[op.adjoint(x).val for op in self._oplist], self._comm
[op.adjoint(x).val for op in self._oplist], self._comm, self._nwork
)
)
return ift.makeField(self.domain, arr)
return ift.makeField(self.domain, arr)
@@ -123,7 +128,6 @@ class SliceLinear(ift.EndomorphicOperator):
@@ -123,7 +128,6 @@ class SliceLinear(ift.EndomorphicOperator):
self._comm = comm
self._comm = comm
self._lo = index_low
self._lo = index_low
local_nwork = [len(oplist)] if comm is None else comm.allgather(len(oplist))
local_nwork = [len(oplist)] if comm is None else comm.allgather(len(oplist))
size, rank, _ = ift.utilities.get_MPI_params_from_comm(comm)
self._nwork = sum(local_nwork)
self._nwork = sum(local_nwork)
def apply(self, x, mode):
def apply(self, x, mode):
@@ -134,6 +138,7 @@ class SliceLinear(ift.EndomorphicOperator):
@@ -134,6 +138,7 @@ class SliceLinear(ift.EndomorphicOperator):
for ii, op in enumerate(self._oplist)
for ii, op in enumerate(self._oplist)
],
],
self._comm,
self._comm,
 
self._nwork
)
)
return ift.makeField(self._domain, res)
return ift.makeField(self._domain, res)
@@ -143,7 +148,7 @@ class SliceLinear(ift.EndomorphicOperator):
@@ -143,7 +148,7 @@ class SliceLinear(ift.EndomorphicOperator):
for ii, op in enumerate(self._oplist):
for ii, op in enumerate(self._oplist):
with ift.random.Context(sseq[self._lo + ii]):
with ift.random.Context(sseq[self._lo + ii]):
local_samples.append(op.draw_sample(from_inverse).val)
local_samples.append(op.draw_sample(from_inverse).val)
res = array_allgather(local_samples, self._comm)
res = array_allgather(local_samples, self._comm, self._nwork)
return ift.makeField(self._domain, res)
return ift.makeField(self._domain, res)
@@ -189,9 +194,28 @@ def _get_global_unique(lst, f, comm):
@@ -189,9 +194,28 @@ def _get_global_unique(lst, f, comm):
return cap
return cap
def array_allgather(arrs, comm, nwork):
    """Gather locally computed arrays from all MPI tasks into one stacked array.

    Parameters
    ----------
    arrs : list of np.ndarray
        Arrays computed by the local task. All arrays (across all tasks)
        must share the same trailing shape and dtype.
    comm : MPI communicator or None
        If None, no communication is performed and the local arrays are
        simply stacked and returned.
    nwork : int
        Total number of arrays across all tasks, i.e. the size of the
        first axis of the result. Must match the distribution produced
        by ``ift.utilities.shareRange``.

    Returns
    -------
    np.ndarray
        Array of shape ``(nwork,) + arrs[0].shape`` containing the
        contributions of all tasks in rank order.
    """
    send_buf = np.array(arrs)
    if comm is None:
        return send_buf

    size = comm.Get_size()
    # Allocate the receive buffer with the dtype of the local data.
    # Hard-coding float64 (MPI.DOUBLE) would silently corrupt any
    # non-double payload during Allgatherv.
    full_lst = np.empty((nwork,) + send_buf.shape[1:], dtype=send_buf.dtype)

    # Flat element count of a single work item (1 for scalar items).
    item_size = int(np.prod(send_buf.shape[1:], dtype=np.int64))

    # Per-rank element counts and displacements, mirroring the slicing
    # done by ift.utilities.shareRange at distribution time. Since the
    # ranges are contiguous, rank r's displacement is simply lo * item_size.
    send_count = []
    displacement = []
    for rank in range(size):
        lo, hi = ift.utilities.shareRange(nwork, size, rank)
        send_count.append((hi - lo) * item_size)
        displacement.append(lo * item_size)

    # mpi4py infers the MPI datatype from the numpy buffers, so no
    # explicit datatype (e.g. MPI.DOUBLE) is needed here.
    comm.Allgatherv(send_buf, [full_lst, tuple(send_count), tuple(displacement)])
    return full_lst
Loading