Commit 4067aa04 authored by Martin Reinecke's avatar Martin Reinecke

bug fixes

parent 3081351a
......@@ -262,6 +262,8 @@ def empty_like(a, dtype=None):
def vdot(a, b):
tmp = np.array(np.vdot(a._data, b._data))
if a._distaxis==-1:
return tmp[()]
res = np.empty((), dtype=tmp.dtype)
_comm.Allreduce(tmp, res, MPI.SUM)
return res[()]
......@@ -309,6 +311,10 @@ def from_object(object, dtype, copy, set_locked):
# algorithm.
def from_random(random_type, shape, dtype=np.float64, **kwargs):
generator_function = getattr(Random, random_type)
if shape == ():
ldat = generator_function(dtype=dtype, shape=shape, **kwargs)
ldat = _comm.bcast(ldat)
return from_local_data(shape, ldat, distaxis=-1)
for i in range(ntask):
lshape = list(shape)
lshape[0] = _shareSize(shape[0], ntask, i)
......
......@@ -31,7 +31,7 @@ class VdotOperator(LinearOperator):
def __init__(self, field):
super(VdotOperator, self).__init__()
self._field = field
self._target = DomainTuple.make(UnstructuredDomain(1))
self._target = DomainTuple.make(())
@property
def domain(self):
......@@ -49,4 +49,4 @@ class VdotOperator(LinearOperator):
self._check_input(x, mode)
if mode == self.TIMES:
return full(self._target, self._field.vdot(x))
return self._field*x.to_global_data()[()]
return self._field*x.local_data[()]
......@@ -243,3 +243,13 @@ class Test_Functionality(unittest.TestCase):
assert_equal((f/f2).local_data, f.local_data/f2.local_data)
assert_equal((-f).local_data, -(f.local_data))
assert_equal(abs(f).local_data, abs(f.local_data))
def test_emptydomain(self):
f = ift.Field.full((), 3.)
assert_equal(f.sum(), 3.)
assert_equal(f.prod(), 3.)
assert_equal(f.local_data, 3.)
assert_equal(f.local_data.shape, ())
assert_equal(f.local_data.size, 1)
assert_equal(f.vdot(f), 9.)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment