Commit 01779e03 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more

parent f24e26e9
...@@ -32,7 +32,8 @@ __all__ = ["consistency_check", "check_jacobian_consistency", ...@@ -32,7 +32,8 @@ __all__ = ["consistency_check", "check_jacobian_consistency",
def assert_allclose(f1, f2, atol, rtol): def assert_allclose(f1, f2, atol, rtol):
if isinstance(f1, Field): if isinstance(f1, Field):
return np.testing.assert_allclose(f1.val, f2.val, atol=atol, rtol=rtol) np.testing.assert_allclose(f1.val, f2.val, atol=atol, rtol=rtol)
else:
for key, val in f1.items(): for key, val in f1.items():
assert_allclose(val, f2[key], atol=atol, rtol=rtol) assert_allclose(val, f2[key], atol=atol, rtol=rtol)
...@@ -103,10 +104,10 @@ def _actual_domain_check_nonlinear(op, loc): ...@@ -103,10 +104,10 @@ def _actual_domain_check_nonlinear(op, loc):
reslin = op(lin) reslin = op(lin)
assert_(lin.domain is op.domain) assert_(lin.domain is op.domain)
assert_(lin.target is op.domain) assert_(lin.target is op.domain)
assert_(lin.val.domain is lin.domain) assert_(lin.fld.domain is lin.domain)
assert_(reslin.domain is op.domain) assert_(reslin.domain is op.domain)
assert_(reslin.target is op.target) assert_(reslin.target is op.target)
assert_(reslin.val.domain is reslin.target) assert_(reslin.fld.domain is reslin.target)
assert_(reslin.target is op.target) assert_(reslin.target is op.target)
assert_(reslin.jac.domain is reslin.domain) assert_(reslin.jac.domain is reslin.domain)
assert_(reslin.jac.target is reslin.target) assert_(reslin.jac.target is reslin.target)
...@@ -150,7 +151,7 @@ def _performance_check(op, pos, raise_on_fail): ...@@ -150,7 +151,7 @@ def _performance_check(op, pos, raise_on_fail):
cond.append(cop.count != 2) cond.append(cop.count != 2)
lin.jac(pos) lin.jac(pos)
cond.append(cop.count != 3) cond.append(cop.count != 3)
lin.jac.adjoint(lin.val) lin.jac.adjoint(lin.fld)
cond.append(cop.count != 4) cond.append(cop.count != 4)
if lin.metric is not None: if lin.metric is not None:
lin.metric(pos) lin.metric(pos)
...@@ -217,20 +218,20 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64, ...@@ -217,20 +218,20 @@ def consistency_check(op, domain_dtype=np.float64, target_dtype=np.float64,
def _get_acceptable_location(op, loc, lin): def _get_acceptable_location(op, loc, lin):
if not np.isfinite(lin.val.s_sum()): if not np.isfinite(lin.fld.s_sum()):
raise ValueError('Initial value must be finite') raise ValueError('Initial value must be finite')
dir = from_random("normal", loc.domain) dir = from_random("normal", loc.domain)
dirder = lin.jac(dir) dirder = lin.jac(dir)
if dirder.norm() == 0: if dirder.norm() == 0:
dir = dir * (lin.val.norm()*1e-5) dir = dir * (lin.fld.norm()*1e-5)
else: else:
dir = dir * (lin.val.norm()*1e-5/dirder.norm()) dir = dir * (lin.fld.norm()*1e-5/dirder.norm())
# Find a step length that leads to a "reasonable" location # Find a step length that leads to a "reasonable" location
for i in range(50): for i in range(50):
try: try:
loc2 = loc+dir loc2 = loc+dir
lin2 = op(Linearization.make_var(loc2, lin.want_metric)) lin2 = op(Linearization.make_var(loc2, lin.want_metric))
if np.isfinite(lin2.val.s_sum()) and abs(lin2.val.s_sum()) < 1e20: if np.isfinite(lin2.fld.s_sum()) and abs(lin2.fld.s_sum()) < 1e20:
break break
except FloatingPointError: except FloatingPointError:
pass pass
...@@ -244,7 +245,7 @@ def _linearization_value_consistency(op, loc): ...@@ -244,7 +245,7 @@ def _linearization_value_consistency(op, loc):
for wm in [False, True]: for wm in [False, True]:
lin = Linearization.make_var(loc, wm) lin = Linearization.make_var(loc, wm)
fld0 = op(loc) fld0 = op(loc)
fld1 = op(lin).val fld1 = op(lin).fld
assert_allclose(fld0, fld1, 0, 1e-7) assert_allclose(fld0, fld1, 0, 1e-7)
...@@ -283,7 +284,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True): ...@@ -283,7 +284,7 @@ def check_jacobian_consistency(op, loc, tol=1e-8, ntries=100, perf_check=True):
locmid = loc + 0.5*dir locmid = loc + 0.5*dir
linmid = op(Linearization.make_var(locmid)) linmid = op(Linearization.make_var(locmid))
dirder = linmid.jac(dir) dirder = linmid.jac(dir)
numgrad = (lin2.val-lin.val) numgrad = (lin2.fld-lin.fld)
xtol = tol * dirder.norm() / np.sqrt(dirder.size) xtol = tol * dirder.norm() / np.sqrt(dirder.size)
hist.append((numgrad-dirder).norm()) hist.append((numgrad-dirder).norm())
# print(len(hist),hist[-1]) # print(len(hist),hist[-1])
......
...@@ -147,6 +147,10 @@ class Field(Operator): ...@@ -147,6 +147,10 @@ class Field(Operator):
arr = generator_function(dtype=dtype, shape=domain.shape, **kwargs) arr = generator_function(dtype=dtype, shape=domain.shape, **kwargs)
return Field(domain, arr) return Field(domain, arr)
@property
def fld(self):
return self
@property @property
def val(self): def val(self):
"""numpy.ndarray : the array storing the field's entries. """numpy.ndarray : the array storing the field's entries.
...@@ -172,6 +176,11 @@ class Field(Operator): ...@@ -172,6 +176,11 @@ class Field(Operator):
"""DomainTuple : the field's domain""" """DomainTuple : the field's domain"""
return self._domain return self._domain
@property
def target(self):
"""DomainTuple : the field's domain"""
return self._domain
@property @property
def shape(self): def shape(self):
"""tuple of int : the concatenated shapes of all sub-domains""" """tuple of int : the concatenated shapes of all sub-domains"""
......
...@@ -132,7 +132,7 @@ class LightConeOperator(Operator): ...@@ -132,7 +132,7 @@ class LightConeOperator(Operator):
def apply(self, x): def apply(self, x):
lin = x.jac is not None lin = x.jac is not None
a, derivs = _cone_arrays(x.val.val if lin else x.val, self.target, self._sigx, lin) a, derivs = _cone_arrays(x.val, self.target, self._sigx, lin)
res = Field(self.target, a) res = Field(self.target, a)
if not lin: if not lin:
return res return res
......
...@@ -79,11 +79,10 @@ class _InterpolationOperator(Operator): ...@@ -79,11 +79,10 @@ class _InterpolationOperator(Operator):
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
lin = x.jac is not None lin = x.jac is not None
xval = x.val.val if lin else x.val res = self._interpolator(x.val)
res = self._interpolator(xval)
res = Field(self._domain, res) res = Field(self._domain, res)
if lin: if lin:
res = x.new(res, makeOp(Field(self._domain, self._deriv(xval)))) res = x.new(res, makeOp(Field(self._domain, self._deriv(x.val))))
if self._inv_table_func is not None: if self._inv_table_func is not None:
res = self._inv_table_func(res) res = self._inv_table_func(res)
return res return res
...@@ -148,11 +147,10 @@ class UniformOperator(Operator): ...@@ -148,11 +147,10 @@ class UniformOperator(Operator):
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
lin = x.jac is not None lin = x.jac is not None
xval = x.val.val if lin else x.val res = Field(self._target, self._scale*norm._cdf(x.val) + self._loc)
res = Field(self._target, self._scale*norm._cdf(xval) + self._loc)
if not lin: if not lin:
return res return res
jac = makeOp(Field(self._domain, norm._pdf(xval)*self._scale)) jac = makeOp(Field(self._domain, norm._pdf(x.val)*self._scale))
return x.new(res, jac) return x.new(res, jac)
def inverse(self, field): def inverse(self, field):
......
...@@ -29,7 +29,7 @@ class Linearization(Operator): ...@@ -29,7 +29,7 @@ class Linearization(Operator):
Parameters Parameters
---------- ----------
val : Field or MultiField fld : Field or MultiField
The value of the operator application. The value of the operator application.
jac : LinearOperator jac : LinearOperator
The Jacobian. The Jacobian.
...@@ -39,38 +39,38 @@ class Linearization(Operator): ...@@ -39,38 +39,38 @@ class Linearization(Operator):
If True, the metric will be computed for other Linearizations derived If True, the metric will be computed for other Linearizations derived
from this one. Default: False. from this one. Default: False.
""" """
def __init__(self, val, jac, metric=None, want_metric=False): def __init__(self, fld, jac, metric=None, want_metric=False):
self._val = val self._fld = fld
self._jac = jac self._jac = jac
if self._val.domain != self._jac.target: if self._fld.domain != self._jac.target:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
self._want_metric = want_metric self._want_metric = want_metric
self._metric = metric self._metric = metric
def new(self, val, jac, metric=None): def new(self, fld, jac, metric=None):
"""Create a new Linearization, taking the `want_metric` property from """Create a new Linearization, taking the `want_metric` property from
this one. this one.
Parameters Parameters
---------- ----------
val : Field or MultiField fld : Field or MultiField
the value of the operator application the value of the operator application
jac : LinearOperator jac : LinearOperator
the Jacobian the Jacobian
metric : LinearOperator or None metric : LinearOperator or None
The metric. Default: None. The metric. Default: None.
""" """
return Linearization(val, jac, metric, self._want_metric) return Linearization(fld, jac, metric, self._want_metric)
def trivial_jac(self): def trivial_jac(self):
return self.make_var(self._val, self._want_metric) return self.make_var(self._fld, self._want_metric)
def prepend_jac(self, jac): def prepend_jac(self, jac):
metric = None metric = None
if self._metric is not None: if self._metric is not None:
from .operators.sandwich_operator import SandwichOperator from .operators.sandwich_operator import SandwichOperator
metric = None if self._metric is None else SandwichOperator.make(jac, self._metric) metric = None if self._metric is None else SandwichOperator.make(jac, self._metric)
return self.new(self._val, self._jac @ jac, metric) return self.new(self._fld, self._jac @ jac, metric)
@property @property
def domain(self): def domain(self):
...@@ -82,10 +82,19 @@ class Linearization(Operator): ...@@ -82,10 +82,19 @@ class Linearization(Operator):
"""DomainTuple or MultiDomain : the Jacobian's target (i.e. the value's domain)""" """DomainTuple or MultiDomain : the Jacobian's target (i.e. the value's domain)"""
return self._jac.target return self._jac.target
@property
def fld(self):
"""Field or MultiField : the pure field-like part of this object"""
return self._fld
@property @property
def val(self): def val(self):
"""Field or MultiField : the value""" """numpy.ndarray or {key: numpy.ndarray} : the numerical value data"""
return self._val return self._fld.val
def val_rw(self):
"""numpy.ndarray or {key: numpy.ndarray} : the numerical value data"""
return self._fld.val_rw()
@property @property
def jac(self): def jac(self):
...@@ -119,30 +128,30 @@ class Linearization(Operator): ...@@ -119,30 +128,30 @@ class Linearization(Operator):
return self._metric return self._metric
def __getitem__(self, name): def __getitem__(self, name):
return self.new(self._val[name], self._jac.ducktape_left(name)) return self.new(self._fld[name], self._jac.ducktape_left(name))
def __neg__(self): def __neg__(self):
return self.new(-self._val, -self._jac, return self.new(-self._fld, -self._jac,
None if self._metric is None else -self._metric) None if self._metric is None else -self._metric)
def conjugate(self): def conjugate(self):
return self.new( return self.new(
self._val.conjugate(), self._jac.conjugate(), self._fld.conjugate(), self._jac.conjugate(),
None if self._metric is None else self._metric.conjugate()) None if self._metric is None else self._metric.conjugate())
@property @property
def real(self): def real(self):
return self.new(self._val.real, self._jac.real) return self.new(self._fld.real, self._jac.real)
def _myadd(self, other, neg): def _myadd(self, other, neg):
if np.isscalar(other) or other.jac is None: if np.isscalar(other) or other.jac is None:
return self.new(self._val-other if neg else self._val+other, return self.new(self.fld-other if neg else self.fld+other,
self._jac, self._metric) self._jac, self._metric)
met = None met = None
if self._metric is not None and other._metric is not None: if self._metric is not None and other._metric is not None:
met = self._metric._myadd(other._metric, neg) met = self._metric._myadd(other._metric, neg)
return self.new( return self.new(
self.val.flexible_addsub(other.val, neg), self.fld.flexible_addsub(other.fld, neg),
self.jac._myadd(other.jac, neg), met) self.jac._myadd(other.jac, neg), met)
def __add__(self, other): def __add__(self, other):
...@@ -175,18 +184,18 @@ class Linearization(Operator): ...@@ -175,18 +184,18 @@ class Linearization(Operator):
if other == 1: if other == 1:
return self return self
met = None if self._metric is None else self._metric.scale(other) met = None if self._metric is None else self._metric.scale(other)
return self.new(self._val*other, self._jac.scale(other), met) return self.new(self.fld*other, self._jac.scale(other), met)
from .sugar import makeOp from .sugar import makeOp
if other.jac is None: if other.jac is None:
if self.target != other.domain: if self.target != other.domain:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
return self.new(self._val*other, makeOp(other)(self._jac)) return self.new(self.fld*other, makeOp(other)(self._jac))
if self.target != other.target: if self.target != other.target:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
return self.new( return self.new(
self.val*other.val, self.fld*other.fld,
(makeOp(other.val)(self.jac))._myadd( (makeOp(other.fld)(self.jac))._myadd(
makeOp(self.val)(other.jac), False)) makeOp(self.fld)(other.jac), False))
def __rmul__(self, other): def __rmul__(self, other):
return self.__mul__(other) return self.__mul__(other)
...@@ -208,12 +217,12 @@ class Linearization(Operator): ...@@ -208,12 +217,12 @@ class Linearization(Operator):
return self.__mul__(other) return self.__mul__(other)
from .operators.outer_product_operator import OuterProduct from .operators.outer_product_operator import OuterProduct
if other.jac is None: if other.jac is None:
return self.new(OuterProduct(self._val, other.domain)(other), return self.new(OuterProduct(self._fld, other.domain)(other),
OuterProduct(self._jac(self._val), other.domain)) OuterProduct(self._jac(self._fld), other.domain))
return self.new( return self.new(
OuterProduct(self._val, other.target)(other._val), OuterProduct(self._fld, other.target)(other._fld),
OuterProduct(self._jac(self._val), other.target)._myadd( OuterProduct(self._jac(self._fld), other.target)._myadd(
OuterProduct(self._val, other.target)(other._jac), False)) OuterProduct(self._fld, other.target)(other._jac), False))
def vdot(self, other): def vdot(self, other):
"""Computes the inner product of this Linearization with a Field or """Computes the inner product of this Linearization with a Field or
...@@ -229,14 +238,18 @@ class Linearization(Operator): ...@@ -229,14 +238,18 @@ class Linearization(Operator):
the inner product of self and other the inner product of self and other
""" """
from .operators.simple_linear_operators import VdotOperator from .operators.simple_linear_operators import VdotOperator
if other is self:
return self.new(
self._fld.vdot(self._fld),
VdotOperator(2*self._fld)(self._jac))
if other.jac is None: if other.jac is None:
return self.new( return self.new(
self._val.vdot(other), self._fld.vdot(other),
VdotOperator(other)(self._jac)) VdotOperator(other)(self._jac))
return self.new( return self.new(
self._val.vdot(other._val), self._fld.vdot(other._fld),
VdotOperator(self._val)(other._jac) + VdotOperator(self._fld)(other._jac) +
VdotOperator(other._val)(self._jac)) VdotOperator(other._fld)(self._jac))
def sum(self, spaces=None): def sum(self, spaces=None):
"""Computes the (partial) sum over self """Computes the (partial) sum over self
...@@ -254,7 +267,7 @@ class Linearization(Operator): ...@@ -254,7 +267,7 @@ class Linearization(Operator):
""" """
from .operators.contraction_operator import ContractionOperator from .operators.contraction_operator import ContractionOperator
return self.new( return self.new(
self._val.sum(spaces), self._fld.sum(spaces),
ContractionOperator(self._jac.target, spaces)(self._jac)) ContractionOperator(self._jac.target, spaces)(self._jac))
def integrate(self, spaces=None): def integrate(self, spaces=None):
...@@ -273,12 +286,12 @@ class Linearization(Operator): ...@@ -273,12 +286,12 @@ class Linearization(Operator):
""" """
from .operators.contraction_operator import ContractionOperator from .operators.contraction_operator import ContractionOperator
return self.new( return self.new(
self._val.integrate(spaces), self._fld.integrate(spaces),
ContractionOperator(self._jac.target, spaces, 1)(self._jac)) ContractionOperator(self._jac.target, spaces, 1)(self._jac))
def ptw(self, op, *args, **kwargs): def ptw(self, op, *args, **kwargs):
from .pointwise import ptw_dict from .pointwise import ptw_dict
t1, t2 = self._val.ptw_with_deriv(op, *args, **kwargs) t1, t2 = self._fld.ptw_with_deriv(op, *args, **kwargs)
return self.new(t1, makeOp(t2)(self._jac)) return self.new(t1, makeOp(t2)(self._jac))
def clip(self, a_min=None, a_max=None): def clip(self, a_min=None, a_max=None):
...@@ -291,10 +304,10 @@ class Linearization(Operator): ...@@ -291,10 +304,10 @@ class Linearization(Operator):
return self.ptw("clip", a_min, a_max) return self.ptw("clip", a_min, a_max)
def add_metric(self, metric): def add_metric(self, metric):
return self.new(self._val, self._jac, metric) return self.new(self._fld, self._jac, metric)
def with_want_metric(self): def with_want_metric(self):
return Linearization(self._val, self._jac, self._metric, True) return Linearization(self._fld, self._jac, self._metric, True)
@staticmethod @staticmethod
def make_var(field, want_metric=False): def make_var(field, want_metric=False):
......
...@@ -47,7 +47,7 @@ class EnergyAdapter(Energy): ...@@ -47,7 +47,7 @@ class EnergyAdapter(Energy):
self._want_metric = want_metric self._want_metric = want_metric
lin = Linearization.make_partial_var(position, constants, want_metric) lin = Linearization.make_partial_var(position, constants, want_metric)
tmp = self._op(lin) tmp = self._op(lin)
self._val = tmp.val.val[()] self._val = tmp.val[()]
self._grad = tmp.gradient self._grad = tmp.gradient
self._metric = tmp._metric self._metric = tmp._metric
......
...@@ -198,10 +198,10 @@ class MetricGaussianKL(Energy): ...@@ -198,10 +198,10 @@ class MetricGaussianKL(Energy):
if self._mirror_samples: if self._mirror_samples:
tmp = tmp + self._hamiltonian(self._lin-s) tmp = tmp + self._hamiltonian(self._lin-s)
if v is None: if v is None:
v = tmp.val.val_rw() v = tmp.val_rw()
g = tmp.gradient g = tmp.gradient
else: else:
v += tmp.val.val v += tmp.val
g = g + tmp.gradient g = g + tmp.gradient
self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples
......
...@@ -83,6 +83,10 @@ class MultiField(Operator): ...@@ -83,6 +83,10 @@ class MultiField(Operator):
def domain(self): def domain(self):
return self._domain return self._domain
@property
def target(self):
return self._domain
# @property # @property
# def dtype(self): # def dtype(self):
# return {key: val.dtype for key, val in self._val.items()} # return {key: val.dtype for key, val in self._val.items()}
...@@ -136,6 +140,10 @@ class MultiField(Operator): ...@@ -136,6 +140,10 @@ class MultiField(Operator):
return MultiField(domain, tuple(Field(dom, val) return MultiField(domain, tuple(Field(dom, val)
for dom in domain._domains)) for dom in domain._domains))
@property
def fld(self):
return self
@property @property
def val(self): def val(self):
return {key: val.val return {key: val.val
......
...@@ -58,10 +58,10 @@ class Squared2NormOperator(EnergyOperator): ...@@ -58,10 +58,10 @@ class Squared2NormOperator(EnergyOperator):
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
res = x.fld.vdot(x.fld)
if x.jac is None: if x.jac is None:
return x.vdot(x) return res
res = x.val.vdot(x.val) return x.new(res, VdotOperator(2*x.fld))
return x.new(res, VdotOperator(2*x.val))
class QuadraticFormOperator(EnergyOperator): class QuadraticFormOperator(EnergyOperator):
...@@ -86,10 +86,10 @@ class QuadraticFormOperator(EnergyOperator): ...@@ -86,10 +86,10 @@ class QuadraticFormOperator(EnergyOperator):