Commit bb5c7a43 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'misc_fixes' into 'NIFTy_5'

various small fixes

See merge request ift/nifty-dev!74
parents c929aad0 558e9549
...@@ -89,7 +89,6 @@ if __name__ == '__main__': ...@@ -89,7 +89,6 @@ if __name__ == '__main__':
data = ift.Field.from_global_data(d_space, data) data = ift.Field.from_global_data(d_space, data)
# Compute likelihood and Hamiltonian # Compute likelihood and Hamiltonian
position = ift.from_random('normal', lamb.position.domain)
likelihood = ift.PoissonianEnergy(lamb, data) likelihood = ift.PoissonianEnergy(lamb, data)
ic_cg = ift.GradientNormController(iteration_limit=50) ic_cg = ift.GradientNormController(iteration_limit=50)
ic_newton = ift.GradientNormController(name='Newton', iteration_limit=50, ic_newton = ift.GradientNormController(name='Newton', iteration_limit=50,
...@@ -103,4 +102,6 @@ if __name__ == '__main__': ...@@ -103,4 +102,6 @@ if __name__ == '__main__':
# Plot results # Plot results
result_sky = sky.at(H.position).value result_sky = sky.at(H.position).value
# FIXME PLOTTING ift.plot(result_sky)
ift.plot_finish()
# FIXME MORE PLOTTING
...@@ -262,6 +262,8 @@ def empty_like(a, dtype=None): ...@@ -262,6 +262,8 @@ def empty_like(a, dtype=None):
def vdot(a, b): def vdot(a, b):
tmp = np.array(np.vdot(a._data, b._data)) tmp = np.array(np.vdot(a._data, b._data))
if a._distaxis==-1:
return tmp[()]
res = np.empty((), dtype=tmp.dtype) res = np.empty((), dtype=tmp.dtype)
_comm.Allreduce(tmp, res, MPI.SUM) _comm.Allreduce(tmp, res, MPI.SUM)
return res[()] return res[()]
...@@ -309,6 +311,10 @@ def from_object(object, dtype, copy, set_locked): ...@@ -309,6 +311,10 @@ def from_object(object, dtype, copy, set_locked):
# algorithm. # algorithm.
def from_random(random_type, shape, dtype=np.float64, **kwargs): def from_random(random_type, shape, dtype=np.float64, **kwargs):
generator_function = getattr(Random, random_type) generator_function = getattr(Random, random_type)
if shape == ():
ldat = generator_function(dtype=dtype, shape=shape, **kwargs)
ldat = _comm.bcast(ldat)
return from_local_data(shape, ldat, distaxis=-1)
for i in range(ntask): for i in range(ntask):
lshape = list(shape) lshape = list(shape)
lshape[0] = _shareSize(shape[0], ntask, i) lshape[0] = _shareSize(shape[0], ntask, i)
......
...@@ -630,7 +630,6 @@ for op in ["__add__", "__radd__", ...@@ -630,7 +630,6 @@ for op in ["__add__", "__radd__",
tval = getattr(self._val, op)(other) tval = getattr(self._val, op)(other)
return Field(self._domain, tval) return Field(self._domain, tval)
raise TypeError("should not arrive here")
return NotImplemented return NotImplemented
return func2 return func2
setattr(Field, op, func(op)) setattr(Field, op, func(op))
......
...@@ -125,6 +125,7 @@ class MultiField(object): ...@@ -125,6 +125,7 @@ class MultiField(object):
@staticmethod @staticmethod
def full(domain, val): def full(domain, val):
domain = MultiDomain.make(domain)
return MultiField(domain, tuple(Field.full(dom, val) return MultiField(domain, tuple(Field.full(dom, val)
for dom in domain._domains)) for dom in domain._domains))
......
from __future__ import absolute_import, division, print_function
import numpy as np import numpy as np
import itertools import itertools
from ..compat import *
from .. import utilities from .. import utilities
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
......
...@@ -243,3 +243,12 @@ class Test_Functionality(unittest.TestCase): ...@@ -243,3 +243,12 @@ class Test_Functionality(unittest.TestCase):
assert_equal((f/f2).local_data, f.local_data/f2.local_data) assert_equal((f/f2).local_data, f.local_data/f2.local_data)
assert_equal((-f).local_data, -(f.local_data)) assert_equal((-f).local_data, -(f.local_data))
assert_equal(abs(f).local_data, abs(f.local_data)) assert_equal(abs(f).local_data, abs(f.local_data))
def test_emptydomain(self):
f = ift.Field.full((), 3.)
assert_equal(f.sum(), 3.)
assert_equal(f.prod(), 3.)
assert_equal(f.local_data, 3.)
assert_equal(f.local_data.shape, ())
assert_equal(f.local_data.size, 1)
assert_equal(f.vdot(f), 9.)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment