Commit cf9d841a authored by Martin Reinecke's avatar Martin Reinecke
Browse files

better, but not good at all

parent 640ed6ea
...@@ -506,37 +506,43 @@ class Field(object): ...@@ -506,37 +506,43 @@ class Field(object):
def s_any(self): def s_any(self):
return self._val.any() return self._val.any()
# def min(self, spaces=None): def min(self, spaces=None):
# """Determines the minimum over the sub-domains given by `spaces`. """Determines the minimum over the sub-domains given by `spaces`.
#
# Parameters Parameters
# ---------- ----------
# spaces : None, int or tuple of int (default: None) spaces : None, int or tuple of int (default: None)
# The operation is only carried out over the sub-domains in this The operation is only carried out over the sub-domains in this
# tuple. If None, it is carried out over all sub-domains. tuple. If None, it is carried out over all sub-domains.
#
# Returns Returns
# ------- -------
# Field Field
# The result of the operation. The result of the operation.
# """ """
# return self._contraction_helper('min', spaces) return self._contraction_helper('min', spaces)
#
# def max(self, spaces=None): def s_min(self):
# """Determines the maximum over the sub-domains given by `spaces`. return self._val.min()
#
# Parameters def max(self, spaces=None):
# ---------- """Determines the maximum over the sub-domains given by `spaces`.
# spaces : None, int or tuple of int (default: None)
# The operation is only carried out over the sub-domains in this Parameters
# tuple. If None, it is carried out over all sub-domains. ----------
# spaces : None, int or tuple of int (default: None)
# Returns The operation is only carried out over the sub-domains in this
# ------- tuple. If None, it is carried out over all sub-domains.
# Field
# The result of the operation. Returns
# """ -------
# return self._contraction_helper('max', spaces) Field
The result of the operation.
"""
return self._contraction_helper('max', spaces)
def s_max(self):
return self._val.max()
def mean(self, spaces=None): def mean(self, spaces=None):
"""Determines the mean over the sub-domains given by `spaces`. """Determines the mean over the sub-domains given by `spaces`.
......
...@@ -102,7 +102,7 @@ class Linearization(Operator): ...@@ -102,7 +102,7 @@ class Linearization(Operator):
----- -----
Only available if target is a scalar Only available if target is a scalar
""" """
return self._jac.adjoint_times(Field.scalar(1.)) return self._jac.adjoint_times(Field.scalar(1.).mult)
@property @property
def want_metric(self): def want_metric(self):
......
...@@ -190,7 +190,7 @@ class MetricGaussianKL(Energy): ...@@ -190,7 +190,7 @@ class MetricGaussianKL(Energy):
v, g = None, None v, g = None, None
if len(self._local_samples) == 0: # hack if there are too many MPI tasks if len(self._local_samples) == 0: # hack if there are too many MPI tasks
tmp = self._hamiltonian(self._lin) tmp = self._hamiltonian(self._lin)
v = 0. * tmp.val.val v = 0. * tmp.val.sing.val
g = 0. * tmp.gradient g = 0. * tmp.gradient
else: else:
for s in self._local_samples: for s in self._local_samples:
...@@ -198,12 +198,12 @@ class MetricGaussianKL(Energy): ...@@ -198,12 +198,12 @@ class MetricGaussianKL(Energy):
if self._mirror_samples: if self._mirror_samples:
tmp = tmp + self._hamiltonian(self._lin-s) tmp = tmp + self._hamiltonian(self._lin-s)
if v is None: if v is None:
v = tmp.val.val_rw() v = tmp.val.sing.val_rw()
g = tmp.gradient g = tmp.gradient
else: else:
v += tmp.val.val v += tmp.val.sing.val
g = g + tmp.gradient g = g + tmp.gradient
self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples self._val = (_np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples)
self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples
self._metric = None self._metric = None
self._sampdt = lh_sampling_dtype self._sampdt = lh_sampling_dtype
......
...@@ -26,12 +26,12 @@ from .minimizer import Minimizer ...@@ -26,12 +26,12 @@ from .minimizer import Minimizer
def _multiToArray(fld): def _multiToArray(fld):
szall = sum(2*v.size if iscomplextype(v.dtype) else v.size szall = sum(2*v.domain.size if iscomplextype(v.dtype) else v.domain.size
for v in fld.values()) for v in fld.values())
res = np.empty(szall, dtype=np.float64) res = np.empty(szall, dtype=np.float64)
ofs = 0 ofs = 0
for val in fld.values(): for val in fld.values():
sz2 = 2*val.size if iscomplextype(val.dtype) else val.size sz2 = 2*val.domain.size if iscomplextype(val.dtype) else val.domain.size
locdat = val.val.reshape(-1) locdat = val.val.reshape(-1)
if iscomplextype(val.dtype): if iscomplextype(val.dtype):
locdat = locdat.view(locdat.real.dtype) locdat = locdat.view(locdat.real.dtype)
...@@ -58,11 +58,11 @@ def _toField(arr, template): ...@@ -58,11 +58,11 @@ def _toField(arr, template):
ofs = 0 ofs = 0
res = [] res = []
for v in template.values(): for v in template.values():
sz2 = 2*v.size if iscomplextype(v.dtype) else v.size sz2 = 2*v.domain.size if iscomplextype(v.dtype) else v.domain.size
locdat = arr[ofs:ofs+sz2].copy() locdat = arr[ofs:ofs+sz2].copy()
if iscomplextype(v.dtype): if iscomplextype(v.dtype):
locdat = locdat.view(np.complex128) locdat = locdat.view(np.complex128)
res.append(Field(v.domain, locdat.reshape(v.shape))) res.append(Field(v.domain, locdat.reshape(v.domain.shape)))
ofs += sz2 ofs += sz2
return MultiField(template.domain, tuple(res)) return MultiField(template.domain, tuple(res))
......
...@@ -218,6 +218,12 @@ class MultiField(Operator): ...@@ -218,6 +218,12 @@ class MultiField(Operator):
""" """
return utilities.my_sum(map(lambda v: v.s_sum(), self._val)) return utilities.my_sum(map(lambda v: v.s_sum(), self._val))
def s_min(self):
return min([v.s_min() for v in self._val])
def s_max(self):
return min([v.s_max() for v in self._val])
def __neg__(self): def __neg__(self):
return self._transform(lambda x: -x) return self._transform(lambda x: -x)
...@@ -243,6 +249,9 @@ class MultiField(Operator): ...@@ -243,6 +249,9 @@ class MultiField(Operator):
self._domain, self._domain,
tuple(self._val[i].where(iftrue[i], iffalse[i]) for i in range(ncomp))) tuple(self._val[i].where(iftrue[i], iffalse[i]) for i in range(ncomp)))
def weight(self, power):
return MultiField(self._domain, tuple(v.weight(power) for v in self._val))
def s_all(self): def s_all(self):
for v in self._val: for v in self._val:
if not v.s_all(): if not v.s_all():
......
...@@ -135,19 +135,19 @@ class FieldAdapter(LinearOperator): ...@@ -135,19 +135,19 @@ class FieldAdapter(LinearOperator):
from ..sugar import makeDomain from ..sugar import makeDomain
tmp = makeDomain(tgt) tmp = makeDomain(tgt)
if isinstance(tmp, DomainTuple): if isinstance(tmp, DomainTuple):
self._target = tmp self._target = tmp.mult
self._domain = MultiDomain.make({name: tmp}) self._domain = MultiDomain.make({name: tmp})
else: else:
self._domain = tmp[name] self._domain = tmp[name].mult
self._target = MultiDomain.make({name: tmp[name]}) self._target = MultiDomain.make({name: tmp[name]})
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
if isinstance(x, MultiField): if isinstance(x, MultiField):
return x.values()[0] return x.values()[0].mult
else: else:
return MultiField(self._tgt(mode), (x,)) return MultiField(self._tgt(mode), (x[""],))
def __repr__(self): def __repr__(self):
s = 'FieldAdapter' s = 'FieldAdapter'
...@@ -237,6 +237,17 @@ def ducktape(left, right, name): ...@@ -237,6 +237,17 @@ def ducktape(left, right, name):
left = left.domain left = left.domain
elif left is not None: elif left is not None:
left = makeDomain(left) left = makeDomain(left)
def _simplify(dom):
if dom is None or isinstance(dom, DomainTuple):
return dom
if isinstance (dom, MultiDomain):
if len(dom) == 1 and dom.keys()[0] == "":
return dom.sing
return dom
left = _simplify(left)
right = _simplify(right)
if left is None: # need to infer left from right if left is None: # need to infer left from right
if isinstance(right, MultiDomain): if isinstance(right, MultiDomain):
left = right[name] left = right[name]
......
...@@ -19,10 +19,10 @@ import numpy as np ...@@ -19,10 +19,10 @@ import numpy as np
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..field import Field from ..field import Field
from .linear_operator import LinearOperator from .linear_operator import LinearOperator_s
class ValueInserter(LinearOperator): class ValueInserter(LinearOperator_s):
"""Inserts one value into a field which is zero otherwise. """Inserts one value into a field which is zero otherwise.
Parameters Parameters
...@@ -34,27 +34,27 @@ class ValueInserter(LinearOperator): ...@@ -34,27 +34,27 @@ class ValueInserter(LinearOperator):
""" """
def __init__(self, target, index): def __init__(self, target, index):
self._domain = DomainTuple.scalar_domain() self._domain_s = DomainTuple.scalar_domain()
self._target = DomainTuple.make(target) self._target_s = DomainTuple.make(target)
index = tuple(index) index = tuple(index)
if not all([ if not all([
isinstance(n, int) and n >= 0 and n < self.target.shape[i] isinstance(n, int) and n >= 0 and n < self._target_s.shape[i]
for i, n in enumerate(index) for i, n in enumerate(index)
]): ]):
raise TypeError raise TypeError
if not len(index) == len(self.target.shape): if not len(index) == len(self._target_s.shape):
raise ValueError raise ValueError
self._index = index self._index = index
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
# Check whether index is in bounds # Check whether index is in bounds
np.empty(self.target.shape)[self._index] np.empty(self._target_s.shape)[self._index]
def apply(self, x, mode): def _apply_s(self, x, mode):
self._check_input(x, mode) self._check_input_s(x, mode)
x = x.val x = x.val
if mode == self.TIMES: if mode == self.TIMES:
res = np.zeros(self.target.shape, dtype=x.dtype) res = np.zeros(self._target_s.shape, dtype=x.dtype)
res[self._index] = x res[self._index] = x
return Field(self._tgt(mode), res) return Field(self._tgt_s(mode), res)
else: else:
return Field.scalar(x[self._index]) return Field.scalar(x[self._index])
...@@ -98,7 +98,7 @@ def get_signal_variance(spec, space): ...@@ -98,7 +98,7 @@ def get_signal_variance(spec, space):
def _single_power_analyze(field, idx, binbounds): def _single_power_analyze(field, idx, binbounds):
power_domain = PowerSpace(field.domain[idx], binbounds) power_domain = PowerSpace(field.domain[idx], binbounds)
pd = PowerDistributor(field.domain, power_domain, idx) pd = PowerDistributor(field.domain, power_domain, idx)
return pd.adjoint_times(field.weight(1)).weight(-1) # divides by bin size return pd.adjoint_times(field.weight(1)).sing.weight(-1) # divides by bin size
# MR FIXME: this function is not well suited for analyzing more than one # MR FIXME: this function is not well suited for analyzing more than one
...@@ -496,7 +496,7 @@ def calculate_position(operator, output): ...@@ -496,7 +496,7 @@ def calculate_position(operator, output):
raise TypeError raise TypeError
if output.domain != operator.target: if output.domain != operator.target:
raise TypeError raise TypeError
cov = 1e-3*output.val.max()**2 cov = 1e-3*output.s_max()**2
invcov = ScalingOperator(output.domain, cov).inverse invcov = ScalingOperator(output.domain, cov).inverse
d = output + invcov.draw_sample(from_inverse=True) d = output + invcov.draw_sample(from_inverse=True)
lh = GaussianEnergy(d, invcov) @ operator lh = GaussianEnergy(d, invcov) @ operator
......
...@@ -76,21 +76,21 @@ def test_quadratic_minimization(minimizer, space): ...@@ -76,21 +76,21 @@ def test_quadratic_minimization(minimizer, space):
@pmp('space', spaces) @pmp('space', spaces)
def test_WF_curvature(space): def test_WF_curvature(space):
required_result = ift.full(space, 1.) required_result = ift.full(space, 1.).mult
s = ift.Field.from_random('uniform', domain=space) + 0.5 s = ift.Field.from_random('uniform', domain=space).mult + 0.5
S = ift.DiagonalOperator(s) S = ift.DiagonalOperator(s)
r = ift.Field.from_random('uniform', domain=space) r = ift.Field.from_random('uniform', domain=space).mult
R = ift.DiagonalOperator(r) R = ift.DiagonalOperator(r)
n = ift.Field.from_random('uniform', domain=space) + 0.5 n = ift.Field.from_random('uniform', domain=space).mult + 0.5
N = ift.DiagonalOperator(n) N = ift.DiagonalOperator(n)
all_diag = 1./s + r**2/n all_diag = 1./s + r**2/n
curv = ift.WienerFilterCurvature(R, N, S, iteration_controller=IC, curv = ift.WienerFilterCurvature(R, N, S, iteration_controller=IC,
iteration_controller_sampling=IC) iteration_controller_sampling=IC)
m = curv.inverse(required_result) m = curv.inverse(required_result)
assert_allclose( assert_allclose(
m.val, m.sing.val,
1./all_diag.val, 1./all_diag.sing.val,
rtol=1e-3, rtol=1e-3,
atol=1e-3) atol=1e-3)
curv.draw_sample() curv.draw_sample()
...@@ -98,7 +98,7 @@ def test_WF_curvature(space): ...@@ -98,7 +98,7 @@ def test_WF_curvature(space):
if len(space.shape) == 1: if len(space.shape) == 1:
R = ift.ValueInserter(space, [0]) R = ift.ValueInserter(space, [0])
n = ift.from_random('uniform', R.domain) + 0.5 n = ift.from_random('uniform', R.domain).mult + 0.5
N = ift.DiagonalOperator(n) N = ift.DiagonalOperator(n)
all_diag = 1./s + R(1/n) all_diag = 1./s + R(1/n)
curv = ift.WienerFilterCurvature(R.adjoint, N, S, curv = ift.WienerFilterCurvature(R.adjoint, N, S,
...@@ -106,8 +106,8 @@ def test_WF_curvature(space): ...@@ -106,8 +106,8 @@ def test_WF_curvature(space):
iteration_controller_sampling=IC) iteration_controller_sampling=IC)
m = curv.inverse(required_result) m = curv.inverse(required_result)
assert_allclose( assert_allclose(
m.val, m.sing.val,
1./all_diag.val, 1./all_diag.sing.val,
rtol=1e-3, rtol=1e-3,
atol=1e-3) atol=1e-3)
curv.draw_sample() curv.draw_sample()
...@@ -176,7 +176,7 @@ def test_rosenbrock(minimizer): ...@@ -176,7 +176,7 @@ def test_rosenbrock(minimizer):
@pmp('minimizer', minimizers + slow_minimizers) @pmp('minimizer', minimizers + slow_minimizers)
def test_gauss(minimizer): def test_gauss(minimizer):
space = ift.UnstructuredDomain((1,)) space = ift.UnstructuredDomain((1,))
starting_point = ift.Field.full(space, 3.) starting_point = ift.Field.full(space, 3.).mult
class ExpEnergy(ift.Energy): class ExpEnergy(ift.Energy):
def __init__(self, position): def __init__(self, position):
...@@ -184,19 +184,19 @@ def test_gauss(minimizer): ...@@ -184,19 +184,19 @@ def test_gauss(minimizer):
@property @property
def value(self): def value(self):
x = self.position.val[0] x = self.position.sing.val[0]
return -np.exp(-(x**2)) return -np.exp(-(x**2))
@property @property
def gradient(self): def gradient(self):
x = self.position.val[0] x = self.position.sing.val[0]
return ift.Field.full(self.position.domain, 2*x*np.exp(-(x**2))) return ift.full(self.position.domain, 2*x*np.exp(-(x**2)))
def apply_metric(self, x): def apply_metric(self, x):
p = self.position.val[0] p = self.position.sing.val[0]
v = (2 - 4*p*p)*np.exp(-p**2) v = (2 - 4*p*p)*np.exp(-p**2)
return ift.DiagonalOperator( return ift.DiagonalOperator(
ift.Field.full(self.position.domain, v))(x) ift.full(self.position.domain, v))(x)
try: try:
minimizer = eval(minimizer) minimizer = eval(minimizer)
...@@ -207,13 +207,13 @@ def test_gauss(minimizer): ...@@ -207,13 +207,13 @@ def test_gauss(minimizer):
raise SkipTest raise SkipTest
assert_equal(convergence, IC.CONVERGED) assert_equal(convergence, IC.CONVERGED)
assert_allclose(energy.position.val, 0., atol=1e-3) assert_allclose(energy.position.sing.val, 0., atol=1e-3)
@pmp('minimizer', minimizers + newton_minimizers + slow_minimizers) @pmp('minimizer', minimizers + newton_minimizers + slow_minimizers)
def test_cosh(minimizer): def test_cosh(minimizer):
space = ift.UnstructuredDomain((1,)) space = ift.UnstructuredDomain((1,))
starting_point = ift.Field.full(space, 3.) starting_point = ift.Field.full(space, 3.).mult
class CoshEnergy(ift.Energy): class CoshEnergy(ift.Energy):
def __init__(self, position): def __init__(self, position):
...@@ -221,26 +221,26 @@ def test_cosh(minimizer): ...@@ -221,26 +221,26 @@ def test_cosh(minimizer):
@property @property
def value(self): def value(self):
x = self.position.val[0] x = self.position.sing.val[0]
return np.cosh(x) return np.cosh(x)
@property @property
def gradient(self): def gradient(self):
x = self.position.val[0] x = self.position.sing.val[0]
return ift.Field.full(self.position.domain, np.sinh(x)) return ift.full(self.position.domain, np.sinh(x))
@property @property
def metric(self): def metric(self):
x = self.position.val[0] x = self.position.sing.val[0]
v = np.cosh(x) v = np.cosh(x)
return ift.DiagonalOperator( return ift.DiagonalOperator(
ift.Field.full(self.position.domain, v)) ift.full(self.position.domain, v))
def apply_metric(self, x): def apply_metric(self, x):
p = self.position.val[0] p = self.position.sing.val[0]
v = np.cosh(p) v = np.cosh(p)
return ift.DiagonalOperator( return ift.DiagonalOperator(
ift.Field.full(self.position.domain, v))(x) ift.full(self.position.domain, v))(x)
try: try:
minimizer = eval(minimizer) minimizer = eval(minimizer)
...@@ -251,4 +251,4 @@ def test_cosh(minimizer): ...@@ -251,4 +251,4 @@ def test_cosh(minimizer):
raise SkipTest raise SkipTest
assert_equal(convergence, IC.CONVERGED) assert_equal(convergence, IC.CONVERGED)
assert_allclose(energy.position.val, 0., atol=1e-3) assert_allclose(energy.position.sing.val, 0., atol=1e-3)
...@@ -33,7 +33,7 @@ space = list2fixture([ ...@@ -33,7 +33,7 @@ space = list2fixture([
def test_property(space): def test_property(space):
diag = ift.Field.from_random('normal', domain=space) diag = ift.Field.from_random('normal', domain=space)
D = ift.DiagonalOperator(diag) D = ift.DiagonalOperator(diag)
if D.domain[0] != space: if D.domain.sing[0] != space:
raise TypeError raise TypeError
...@@ -42,8 +42,8 @@ def test_times_adjoint(space): ...@@ -42,8 +42,8 @@ def test_times_adjoint(space):
rand2 = ift.Field.from_random('normal', domain=space) rand2 = ift.Field.from_random('normal', domain=space)
diag = ift.Field.from_random('normal', domain=space) diag = ift.Field.from_random('normal', domain=space)
D = ift.DiagonalOperator(diag) D = ift.DiagonalOperator(diag)
tt1 = rand1.s_vdot(D.times(rand2)) tt1 = rand1.sing.s_vdot(D.times(rand2).sing)
tt2 = rand2.s_vdot(D.times(rand1)) tt2 = rand2.sing.s_vdot(D.times(rand1).sing)
assert_allclose(tt1, tt2) assert_allclose(tt1, tt2)
...@@ -52,7 +52,7 @@ def test_times_inverse(space): ...@@ -52,7 +52,7 @@ def test_times_inverse(space):
diag = ift.Field.from_random('normal', domain=space) diag = ift.Field.from_random('normal', domain=space)
D = ift.DiagonalOperator(diag) D = ift.DiagonalOperator(diag)
tt1 = D.times(D.inverse_times(rand1)) tt1 = D.times(D.inverse_times(rand1))
assert_allclose(rand1.val, tt1.val) assert_allclose(rand1.sing.val, tt1.sing.val)