Commit 558e9549 authored by Martin Reinecke's avatar Martin Reinecke

various small fixes

parent a2f53b17
......@@ -89,7 +89,6 @@ if __name__ == '__main__':
data = ift.Field.from_global_data(d_space, data)
# Compute likelihood and Hamiltonian
position = ift.from_random('normal', lamb.position.domain)
likelihood = ift.PoissonianEnergy(lamb, data)
ic_cg = ift.GradientNormController(iteration_limit=50)
ic_newton = ift.GradientNormController(name='Newton', iteration_limit=50,
......@@ -103,4 +102,6 @@ if __name__ == '__main__':
# Plot results
result_sky = sky.at(H.position).value
# FIXME PLOTTING
ift.plot(result_sky)
ift.plot_finish()
# FIXME MORE PLOTTING
......@@ -262,6 +262,8 @@ def empty_like(a, dtype=None):
def vdot(a, b):
tmp = np.array(np.vdot(a._data, b._data))
if a._distaxis==-1:
return tmp[()]
res = np.empty((), dtype=tmp.dtype)
_comm.Allreduce(tmp, res, MPI.SUM)
return res[()]
......@@ -309,6 +311,10 @@ def from_object(object, dtype, copy, set_locked):
# algorithm.
def from_random(random_type, shape, dtype=np.float64, **kwargs):
generator_function = getattr(Random, random_type)
if shape == ():
ldat = generator_function(dtype=dtype, shape=shape, **kwargs)
ldat = _comm.bcast(ldat)
return from_local_data(shape, ldat, distaxis=-1)
for i in range(ntask):
lshape = list(shape)
lshape[0] = _shareSize(shape[0], ntask, i)
......
......@@ -630,7 +630,6 @@ for op in ["__add__", "__radd__",
tval = getattr(self._val, op)(other)
return Field(self._domain, tval)
raise TypeError("should not arrive here")
return NotImplemented
return func2
setattr(Field, op, func(op))
......
......@@ -125,6 +125,7 @@ class MultiField(object):
@staticmethod
def full(domain, val):
domain = MultiDomain.make(domain)
return MultiField(domain, tuple(Field.full(dom, val)
for dom in domain._domains))
......
from __future__ import absolute_import, division, print_function
import numpy as np
import itertools
from ..compat import *
from .. import utilities
from .linear_operator import LinearOperator
from ..domain_tuple import DomainTuple
......
......@@ -243,3 +243,12 @@ class Test_Functionality(unittest.TestCase):
assert_equal((f/f2).local_data, f.local_data/f2.local_data)
assert_equal((-f).local_data, -(f.local_data))
assert_equal(abs(f).local_data, abs(f.local_data))
def test_emptydomain(self):
f = ift.Field.full((), 3.)
assert_equal(f.sum(), 3.)
assert_equal(f.prod(), 3.)
assert_equal(f.local_data, 3.)
assert_equal(f.local_data.shape, ())
assert_equal(f.local_data.size, 1)
assert_equal(f.vdot(f), 9.)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment