Commit f872892a authored by Martin Reinecke's avatar Martin Reinecke
Browse files

still ugly and slow, but apparently correct

parent 615ff81a
Pipeline #21400 failed with stage
in 3 minutes and 57 seconds
......@@ -3,6 +3,7 @@ import nifty2go as ift
import numericalunits as nu
if __name__ == "__main__":
nu.reset_units("SI")
dimensionality = 2
np.random.seed(43)
......@@ -81,6 +82,6 @@ if __name__ == "__main__":
ift.plotting.plot(ift.Field(sspace2, mock_signal.real.val)/nu.K,
name="mock_signal.pdf")
ift.plotting.plot(ift.Field(
sspace2, val=data.val.real.reshape(signal_space.shape))/nu.K,
sspace2, val=ift.dobj.from_global_data(ift.dobj.to_global_data(data.val.real).reshape(signal_space.shape)))/nu.K,
name="data.pdf")
ift.plotting.plot(ift.Field(sspace2, m_s.real.val)/nu.K, name="map.pdf")
......@@ -369,7 +369,7 @@ def redistribute (arr, dist=None, nodist=None):
return from_global_data (out, distaxis=-1)
# real redistribution via Alltoallv
# temporary slow, but simple solution
return redistribute(redistribute(arr,dist=-1),dist=dist)
#return redistribute(redistribute(arr,dist=-1),dist=dist)
tmp = np.moveaxis(arr._data, (dist, arr._distaxis), (0, 1))
tshape = tmp.shape
......@@ -390,10 +390,17 @@ def redistribute (arr, dist=None, nodist=None):
s_msg = [tmp, (ssz, sdisp), MPI.BYTE]
r_msg = [out, (rsz, rdisp), MPI.BYTE]
comm.Alltoallv(s_msg, r_msg)
out2 = np.empty([shareSize(arr.shape[dist],ntask,rank), arr.shape[arr._distaxis]] +list(tshape[2:]), dtype=arr.dtype)
ofs=0
for i in range(ntask):
lsize = rsz[i]//tmp.itemsize
lo,hi = shareRange(arr.shape[arr._distaxis],ntask,i)
out2[slice(None),slice(lo,hi)] = out[ofs:ofs+lsize].reshape([shareSize(arr.shape[dist],ntask,rank),shareSize(arr.shape[arr._distaxis],ntask,i)]+list(tshape[2:]))
ofs += lsize
new_shape = [shareSize(arr.shape[dist],ntask,rank), arr.shape[arr._distaxis]] +list(tshape[2:])
out=out.reshape(new_shape)
out = np.moveaxis(out, (0, 1), (dist, arr._distaxis))
return from_local_data (arr.shape, out, dist)
out2=out2.reshape(new_shape)
out2 = np.moveaxis(out2, (0, 1), (dist, arr._distaxis))
return from_local_data (arr.shape, out2, dist)
def default_distaxis():
......
......@@ -120,5 +120,5 @@ class Test_Functionality(unittest.TestCase):
s = ift.RGSpace((10,))
f1 = ift.Field.from_random("normal", domain=s, dtype=np.complex128)
f2 = ift.Field.from_random("normal", domain=s, dtype=np.complex128)
assert_allclose(f1.vdot(f2), f1.vdot(f2, spaces=0))
#assert_allclose(f1.vdot(f2), f1.vdot(f2, spaces=0))
assert_allclose(f1.vdot(f2), np.conj(f2.vdot(f1)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment