Commit 01779e03 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more

parent f24e26e9
......@@ -32,9 +32,10 @@ __all__ = ["consistency_check", "check_jacobian_consistency",
def assert_allclose(f1, f2, atol, rtol):
if isinstance(f1, Field):
return np.testing.assert_allclose(f1.val, f2.val, atol=atol, rtol=rtol)
for key, val in f1.items():
assert_allclose(val, f2[key], atol=atol, rtol=rtol)
np.testing.assert_allclose(f1.val, f2.val, atol=atol, rtol=rtol)
else:
for key, val in f1.items():
assert_allclose(val, f2[key], atol=atol, rtol=rtol)
def _adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol,
......@@ -103,10 +104,10 @@ def _actual_domain_check_nonlinear(op, loc):
reslin = op(lin)
assert_(lin.domain is op.domain)
assert_(lin.target is op.domain)
assert_(lin.val.domain is lin.domain)
assert_(lin.fld.domain is lin.domain)
assert_(reslin.domain is op.domain)
assert_(reslin.target is op.target)
assert_(reslin.val.domain is reslin.target)
assert_(reslin.fld.domain is reslin.target)
assert_(reslin.target is op.target)
assert_(reslin.jac.domain is reslin.domain)
assert_(reslin.jac.target is reslin.target)
......@@ -150,7 +151,7 @@ def _performance_check(op, pos, raise_on_fail):
cond.append(cop.count != 2)
lin.jac(pos)
cond.append(cop.count != 3)
lin.jac.adjoint(lin.val)
lin.jac.adjoint(lin.fld)
cond.append(cop.count != 4)
if lin.metric is not None:
lin.metric(pos)
......@@ -217,20 +218,20 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
def _get_acceptable_location(op, loc, lin):
if not np.isfinite(lin.val.s_sum()):
if not np.isfinite(lin.fld.s_sum()):
raise ValueError('Initial value must be finite')
dir = from_random("normal", loc.domain)
dirder = lin.jac(dir)
if dirder.norm() == 0:
dir = dir * (lin.val.norm()*1e-5)
dir = dir * (lin.fld.norm()*1e-5)
else:
dir = dir * (lin.val.norm()*1e-5/dirder.norm())
dir = dir * (lin.fld.norm()*1e-5/dirder.norm())
# Find a step length that leads to a "reasonable" location
for i in range(50):
try:
loc2 = loc+dir
lin2 = op(Linearization.make_var(loc2, lin.want_metric))
if np.isfinite(lin2.val.s_sum()) and abs(lin2.val.s_sum()) < 1e20:
if np.isfinite(lin2.fld.s_sum()) and abs(lin2.fld.s_sum()) < 1e20:
break
except FloatingPointError:
pass
......@@ -244,7 +245,7 @@ def _linearization_value_consistency(op, loc):
for wm in [False, True]:
lin = Linearization.make_var(loc, wm)
fld0 = op(loc)
fld1 = op(lin).val
fld1 = op(lin).fld
assert_allclose(fld0, fld1, 0, 1e-7)
......@@ -283,7 +284,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True):
locmid = loc + 0.5*dir
linmid = op(Linearization.make_var(locmid))
dirder = linmid.jac(dir)
numgrad = (lin2.val-lin.val)
numgrad = (lin2.fld-lin.fld)
xtol = tol * dirder.norm() / np.sqrt(dirder.size)
hist.append((numgrad-dirder).norm())
# print(len(hist),hist[-1])
......
......@@ -147,6 +147,10 @@ class Field(Operator):
arr = generator_function(dtype=dtype, shape=domain.shape, **kwargs)
return Field(domain, arr)
@property
def fld(self):
return self
@property
def val(self):
"""numpy.ndarray : the array storing the field's entries.
......@@ -172,6 +176,11 @@ class Field(Operator):
"""DomainTuple : the field's domain"""
return self._domain
@property
def target(self):
"""DomainTuple : the field's domain"""
return self._domain
@property
def shape(self):
"""tuple of int : the concatenated shapes of all sub-domains"""
......
......@@ -132,7 +132,7 @@ class LightConeOperator(Operator):
def apply(self, x):
lin = x.jac is not None
a, derivs = _cone_arrays(x.val.val if lin else x.val, self.target, self._sigx, lin)
a, derivs = _cone_arrays(x.val, self.target, self._sigx, lin)
res = Field(self.target, a)
if not lin:
return res
......
......@@ -79,11 +79,10 @@ class _InterpolationOperator(Operator):
def apply(self, x):
self._check_input(x)
lin = x.jac is not None
xval = x.val.val if lin else x.val
res = self._interpolator(xval)
res = self._interpolator(x.val)
res = Field(self._domain, res)
if lin:
res = x.new(res, makeOp(Field(self._domain, self._deriv(xval))))
res = x.new(res, makeOp(Field(self._domain, self._deriv(x.val))))
if self._inv_table_func is not None:
res = self._inv_table_func(res)
return res
......@@ -148,11 +147,10 @@ class UniformOperator(Operator):
def apply(self, x):
self._check_input(x)
lin = x.jac is not None
xval = x.val.val if lin else x.val
res = Field(self._target, self._scale*norm._cdf(xval) + self._loc)
res = Field(self._target, self._scale*norm._cdf(x.val) + self._loc)
if not lin:
return res
jac = makeOp(Field(self._domain, norm._pdf(xval)*self._scale))
jac = makeOp(Field(self._domain, norm._pdf(x.val)*self._scale))
return x.new(res, jac)
def inverse(self, field):
......
......@@ -29,7 +29,7 @@ class Linearization(Operator):
Parameters
----------
val : Field or MultiField
fld : Field or MultiField
The value of the operator application.
jac : LinearOperator
The Jacobian.
......@@ -39,38 +39,38 @@ class Linearization(Operator):
If True, the metric will be computed for other Linearizations derived
from this one. Default: False.
"""
def __init__(self, val, jac, metric=None, want_metric=False):
self._val = val
def __init__(self, fld, jac, metric=None, want_metric=False):
self._fld = fld
self._jac = jac
if self._val.domain != self._jac.target:
if self._fld.domain != self._jac.target:
raise ValueError("domain mismatch")
self._want_metric = want_metric
self._metric = metric
def new(self, val, jac, metric=None):
def new(self, fld, jac, metric=None):
"""Create a new Linearization, taking the `want_metric` property from
this one.
Parameters
----------
val : Field or MultiField
fld : Field or MultiField
the value of the operator application
jac : LinearOperator
the Jacobian
metric : LinearOperator or None
The metric. Default: None.
"""
return Linearization(val, jac, metric, self._want_metric)
return Linearization(fld, jac, metric, self._want_metric)
def trivial_jac(self):
return self.make_var(self._val, self._want_metric)
return self.make_var(self._fld, self._want_metric)
def prepend_jac(self, jac):
metric = None
if self._metric is not None:
from .operators.sandwich_operator import SandwichOperator
metric = None if self._metric is None else SandwichOperator.make(jac, self._metric)
return self.new(self._val, self._jac @ jac, metric)
return self.new(self._fld, self._jac @ jac, metric)
@property
def domain(self):
......@@ -82,10 +82,19 @@ class Linearization(Operator):
"""DomainTuple or MultiDomain : the Jacobian's target (i.e. the value's domain)"""
return self._jac.target
@property
def fld(self):
"""Field or MultiField : the pure field-like part of this object"""
return self._fld
@property
def val(self):
"""Field or MultiField : the value"""
return self._val
"""numpy.ndarray or {key: numpy.ndarray} : the numerical value data"""
return self._fld.val
def val_rw(self):
"""numpy.ndarray or {key: numpy.ndarray} : the numerical value data"""
return self._fld.val_rw()
@property
def jac(self):
......@@ -119,30 +128,30 @@ class Linearization(Operator):
return self._metric
def __getitem__(self, name):
return self.new(self._val[name], self._jac.ducktape_left(name))
return self.new(self._fld[name], self._jac.ducktape_left(name))
def __neg__(self):
return self.new(-self._val, -self._jac,
return self.new(-self._fld, -self._jac,
None if self._metric is None else -self._metric)
def conjugate(self):
return self.new(
self._val.conjugate(), self._jac.conjugate(),
self._fld.conjugate(), self._jac.conjugate(),
None if self._metric is None else self._metric.conjugate())
@property
def real(self):
return self.new(self._val.real, self._jac.real)
return self.new(self._fld.real, self._jac.real)
def _myadd(self, other, neg):
if np.isscalar(other) or other.jac is None:
return self.new(self._val-other if neg else self._val+other,
return self.new(self.fld-other if neg else self.fld+other,
self._jac, self._metric)
met = None
if self._metric is not None and other._metric is not None:
met = self._metric._myadd(other._metric, neg)
return self.new(
self.val.flexible_addsub(other.val, neg),
self.fld.flexible_addsub(other.fld, neg),
self.jac._myadd(other.jac, neg), met)
def __add__(self, other):
......@@ -175,18 +184,18 @@ class Linearization(Operator):
if other == 1:
return self
met = None if self._metric is None else self._metric.scale(other)
return self.new(self._val*other, self._jac.scale(other), met)
return self.new(self.fld*other, self._jac.scale(other), met)
from .sugar import makeOp
if other.jac is None:
if self.target != other.domain:
raise ValueError("domain mismatch")
return self.new(self._val*other, makeOp(other)(self._jac))
return self.new(self.fld*other, makeOp(other)(self._jac))
if self.target != other.target:
raise ValueError("domain mismatch")
return self.new(
self.val*other.val,
(makeOp(other.val)(self.jac))._myadd(
makeOp(self.val)(other.jac), False))
self.fld*other.fld,
(makeOp(other.fld)(self.jac))._myadd(
makeOp(self.fld)(other.jac), False))
def __rmul__(self, other):
return self.__mul__(other)
......@@ -208,12 +217,12 @@ class Linearization(Operator):
return self.__mul__(other)
from .operators.outer_product_operator import OuterProduct
if other.jac is None:
return self.new(OuterProduct(self._val, other.domain)(other),
OuterProduct(self._jac(self._val), other.domain))
return self.new(OuterProduct(self._fld, other.domain)(other),
OuterProduct(self._jac(self._fld), other.domain))
return self.new(
OuterProduct(self._val, other.target)(other._val),
OuterProduct(self._jac(self._val), other.target)._myadd(
OuterProduct(self._val, other.target)(other._jac), False))
OuterProduct(self._fld, other.target)(other._fld),
OuterProduct(self._jac(self._fld), other.target)._myadd(
OuterProduct(self._fld, other.target)(other._jac), False))
def vdot(self, other):
"""Computes the inner product of this Linearization with a Field or
......@@ -229,14 +238,18 @@ class Linearization(Operator):
the inner product of self and other
"""
from .operators.simple_linear_operators import VdotOperator
if other is self:
return self.new(
self._fld.vdot(self._fld),
VdotOperator(2*self._fld)(self._jac))
if other.jac is None:
return self.new(
self._val.vdot(other),
self._fld.vdot(other),
VdotOperator(other)(self._jac))
return self.new(
self._val.vdot(other._val),
VdotOperator(self._val)(other._jac) +
VdotOperator(other._val)(self._jac))
self._fld.vdot(other._fld),
VdotOperator(self._fld)(other._jac) +
VdotOperator(other._fld)(self._jac))
def sum(self, spaces=None):
"""Computes the (partial) sum over self
......@@ -254,7 +267,7 @@ class Linearization(Operator):
"""
from .operators.contraction_operator import ContractionOperator
return self.new(
self._val.sum(spaces),
self._fld.sum(spaces),
ContractionOperator(self._jac.target, spaces)(self._jac))
def integrate(self, spaces=None):
......@@ -273,12 +286,12 @@ class Linearization(Operator):
"""
from .operators.contraction_operator import ContractionOperator
return self.new(
self._val.integrate(spaces),
self._fld.integrate(spaces),
ContractionOperator(self._jac.target, spaces, 1)(self._jac))
def ptw(self, op, *args, **kwargs):
from .pointwise import ptw_dict
t1, t2 = self._val.ptw_with_deriv(op, *args, **kwargs)
t1, t2 = self._fld.ptw_with_deriv(op, *args, **kwargs)
return self.new(t1, makeOp(t2)(self._jac))
def clip(self, a_min=None, a_max=None):
......@@ -291,10 +304,10 @@ class Linearization(Operator):
return self.ptw("clip", a_min, a_max)
def add_metric(self, metric):
return self.new(self._val, self._jac, metric)
return self.new(self._fld, self._jac, metric)
def with_want_metric(self):
return Linearization(self._val, self._jac, self._metric, True)
return Linearization(self._fld, self._jac, self._metric, True)
@staticmethod
def make_var(field, want_metric=False):
......
......@@ -47,7 +47,7 @@ class EnergyAdapter(Energy):
self._want_metric = want_metric
lin = Linearization.make_partial_var(position, constants, want_metric)
tmp = self._op(lin)
self._val = tmp.val.val[()]
self._val = tmp.val[()]
self._grad = tmp.gradient
self._metric = tmp._metric
......
......@@ -198,10 +198,10 @@ class MetricGaussianKL(Energy):
if self._mirror_samples:
tmp = tmp + self._hamiltonian(self._lin-s)
if v is None:
v = tmp.val.val_rw()
v = tmp.val_rw()
g = tmp.gradient
else:
v += tmp.val.val
v += tmp.val
g = g + tmp.gradient
self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples
......
......@@ -83,6 +83,10 @@ class MultiField(Operator):
def domain(self):
return self._domain
@property
def target(self):
return self._domain
# @property
# def dtype(self):
# return {key: val.dtype for key, val in self._val.items()}
......@@ -136,6 +140,10 @@ class MultiField(Operator):
return MultiField(domain, tuple(Field(dom, val)
for dom in domain._domains))
@property
def fld(self):
return self
@property
def val(self):
return {key: val.val
......
......@@ -58,10 +58,10 @@ class Squared2NormOperator(EnergyOperator):
def apply(self, x):
self._check_input(x)
res = x.fld.vdot(x.fld)
if x.jac is None:
return x.vdot(x)
res = x.val.vdot(x.val)
return x.new(res, VdotOperator(2*x.val))
return res
return x.new(res, VdotOperator(2*x.fld))
class QuadraticFormOperator(EnergyOperator):
......@@ -86,10 +86,10 @@ class QuadraticFormOperator(EnergyOperator):
def apply(self, x):
self._check_input(x)
res = 0.5*x.fld.vdot(self._op(x.fld))
if x.jac is None:
return 0.5*x.vdot(self._op(x))
res = 0.5*x.val.vdot(self._op(x.val))
return x.new(res, VdotOperator(self._op(x.val)))
return res
return x.new(res, VdotOperator(self._op(x.fld)))
class VariableCovarianceGaussianEnergy(EnergyOperator):
......@@ -128,7 +128,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov]).real - x[self._icov].ptw("log").sum())
if not x.want_metric:
return res
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
mf = {self._r: x.fld[self._icov], self._icov: .5*x.fld[self._icov]**(-2)}
return res.add_metric(makeOp(MultiField.from_dict(mf)))
......@@ -230,7 +230,7 @@ class PoissonianEnergy(EnergyOperator):
res = x.sum() - x.ptw("log").vdot(self._d)
if not x.want_metric:
return res
return res.add_metric(makeOp(1./x.val))
return res.add_metric(makeOp(x.fld.ptw("reciprocal")))
class InverseGammaLikelihood(EnergyOperator):
......@@ -270,7 +270,7 @@ class InverseGammaLikelihood(EnergyOperator):
res = x.ptw("log").vdot(self._alphap1) + x.ptw("reciprocal").vdot(self._beta)
if not x.want_metric:
return res
return res.add_metric(makeOp(self._alphap1/(x.val**2)))
return res.add_metric(makeOp(self._alphap1/(x.fld**2)))
class StudentTEnergy(EnergyOperator):
......@@ -333,7 +333,7 @@ class BernoulliEnergy(EnergyOperator):
res = -x.ptw("log").vdot(self._d) + (1.-x).ptw("log").vdot(self._d-1.)
if not x.want_metric:
return res
return res.add_metric(makeOp(1./(x.val*(1. - x.val))))
return res.add_metric(makeOp(1./(x.fld*(1. - x.fld))))
class StandardHamiltonian(EnergyOperator):
......
......@@ -172,7 +172,7 @@ class LinearOperator(Operator):
"""Same as :meth:`times`"""
from ..linearization import Linearization
if x.jac is not None:
return x.new(self(x._val), self).prepend_jac(x.jac)
return x.new(self(x.fld), self).prepend_jac(x.jac)
if x.val is not None:
return self.apply(x, self.TIMES)
return self@x
......
......@@ -45,11 +45,23 @@ class Operator(metaclass=NiftyMeta):
"""
return self._target
@property
def fld(self):
"""The field associated with this object
For "pure" operators this is `None`. For Field-like objects this
is a `Field` or a `MultiField` matching the object's `target`.
Returns
-------
None or Field or MultiField : the field object
"""
return None
@property
def val(self):
"""The numerical value associated with this object
For "pure" operators this is `None`. For Field-like objects this
is a `numpy.ndarray` or a dictionary of `numpy.ndarray`s mathcing the
is a `numpy.ndarray` or a dictionary of `numpy.ndarray`s matching the
object's `target`.
Returns
......@@ -421,16 +433,16 @@ class _OpProd(Operator):
from ..sugar import makeOp
self._check_input(x)
lin = x.jac is not None
wm = x.want_metric if lin else False
x = x.val if lin else x
wm = x.want_metric
x = x.fld if lin else x
v1 = x.extract(self._op1.domain)
v2 = x.extract(self._op2.domain)
if not lin:
return self._op1(v1) * self._op2(v2)
lin1 = self._op1(Linearization.make_var(v1, wm))
lin2 = self._op2(Linearization.make_var(v2, wm))
jac = (makeOp(lin1._val)(lin2._jac))._myadd(makeOp(lin2._val)(lin1._jac), False)
return lin1.new(lin1._val*lin2._val, jac)
jac = (makeOp(lin1._fld)(lin2._jac))._myadd(makeOp(lin2._fld)(lin1._jac), False)
return lin1.new(lin1._fld*lin2._fld, jac)
def _simplify_for_constant_input_nontrivial(self, c_inp):
f1, o1 = self._op1.simplify_for_constant_input(
......@@ -467,13 +479,13 @@ class _OpSum(Operator):
v1 = x.extract(self._op1.domain)
v2 = x.extract(self._op2.domain)
return self._op1(v1).unite(self._op2(v2))
v1 = x.val.extract(self._op1.domain)
v2 = x.val.extract(self._op2.domain)
v1 = x.fld.extract(self._op1.domain)
v2 = x.fld.extract(self._op2.domain)
wm = x.want_metric
lin1 = self._op1(Linearization.make_var(v1, wm))
lin2 = self._op2(Linearization.make_var(v2, wm))
op = lin1._jac._myadd(lin2._jac, False)
res = lin1.new(lin1._val.unite(lin2._val), op)
res = lin1.new(lin1._fld.unite(lin2._fld), op)
if lin1._metric is not None and lin2._metric is not None:
res = res.add_metric(lin1._metric._myadd(lin2._metric, False))
return res
......
......@@ -63,8 +63,8 @@ def test_actual_gradients(f):
eps = 1e-8
var0 = ift.Linearization.make_var(fld)
var1 = ift.Linearization.make_var(fld + eps)
f0 = var0.ptw(f).val.val
f1 = var1.ptw(f).val.val
f0 = var0.ptw(f).val
f1 = var1.ptw(f).val
df0 = (f1 - f0)/eps
df1 = _lin2grad(var0.ptw(f))
assert_allclose(df0, df1, rtol=100*eps)
......@@ -43,7 +43,7 @@ def testBasics(space, seed):
s = S.draw_sample()
var = ift.Linearization.make_var(s)
model = ift.ScalingOperator(var.target, 6.)
ift.extra.check_jacobian_consistency(model, var.val)
ift.extra.check_jacobian_consistency(model, var.fld)
@pmp('type1', ['Variable', 'Constant'])
......
......@@ -46,10 +46,10 @@ def test_simplification():
o2.ducktape("b").ducktape_left("b"))
_, op2 = op.simplify_for_constant_input(f2)
assert_equal(isinstance(op2._op1, _ConstantOperator), True)
assert_allclose(op(f1)["a"].val, op2(f1)["a"].val)
assert_allclose(op(f1)["b"].val, op2(f1)["b"].val)
assert_allclose(op(f1).val["a"], op2(f1).val["a"])
assert_allclose(op(f1).val["b"], op2(f1).val["b"])
lin = ift.Linearization.make_var(ift.MultiField.full(op2.domain, 2.), True)
assert_allclose(op(lin).val["a"].val,
op2(lin).val["a"].val)
assert_allclose(op(lin).val["b"].val,
op2(lin).val["b"].val)
assert_allclose(op(lin).val["a"],
op2(lin).val["a"])
assert_allclose(op(lin).val["b"],
op2(lin).val["b"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment