Commit 2e57bec1 authored by Martin Reinecke's avatar Martin Reinecke

rework pointwise operations

parent 8a3b5d58
Pipeline #72233 canceled with stages
in 1 minute and 49 seconds
......@@ -142,8 +142,7 @@ class RGSpace(StructuredDomain):
@staticmethod
def _kernel(x, sigma):
from ..sugar import exp
return exp(x*x * (-2.*np.pi*np.pi*sigma*sigma))
return (x*x * (-2.*np.pi*np.pi*sigma*sigma)).ptw("exp")
def get_fft_smoothing_kernel_function(self, sigma):
if (not self.harmonic):
......
......@@ -634,10 +634,9 @@ class Field(object):
Field
The result of the operation.
"""
from .sugar import sqrt
if self.scalar_weight(spaces) is not None:
return self._contraction_helper('std', spaces)
return sqrt(self.var(spaces))
return self.var(spaces).ptw("sqrt")
def s_std(self):
"""Determines the standard deviation of the Field.
......@@ -677,17 +676,11 @@ class Field(object):
def flexible_addsub(self, other, neg):
return self-other if neg else self+other
def sigmoid(self):
return 0.5*(1.+self.tanh())
def clip(self, min=None, max=None):
min = min.val if isinstance(min, Field) else min
max = max.val if isinstance(max, Field) else max
return Field(self._domain, np.clip(self._val, min, max))
def one_over(self):
return 1/self
def _binary_op(self, other, op):
# if other is a field, make sure that the domains match
f = getattr(self._val, op)
......@@ -699,6 +692,13 @@ class Field(object):
return Field(self._domain, f(other))
return NotImplemented
def ptw(self, op, with_deriv=False):
from .pointwise import ptw_dict
if with_deriv:
tmp = ptw_dict[op][1](self._val)
return (Field(self._domain, tmp[0]),
Field(self._domain, tmp[1]))
return Field(self._domain, ptw_dict[op][0](self._val))
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
......@@ -721,11 +721,3 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"In-place operations are deliberately not supported")
return func2
setattr(Field, op, func(op))
for f in ["sqrt", "exp", "log", "sin", "cos", "tan", "sinh", "cosh", "tanh",
"absolute", "sinc", "sign", "log10", "log1p", "expm1"]:
def func(f):
def func2(self):
return Field(self._domain, getattr(np, f)(self.val))
return func2
setattr(Field, f, func(f))
......@@ -126,7 +126,7 @@ class _LognormalMomentMatching(Operator):
logmean, logsig = _lognormal_moments(mean, sig, N_copies)
self._mean = mean
self._sig = sig
op = _normal(logmean, logsig, key, N_copies).exp()
op = _normal(logmean, logsig, key, N_copies).ptw("exp")
self._domain, self._target = op.domain, op.target
self.apply = op.apply
......@@ -224,8 +224,8 @@ class _Normalization(Operator):
def apply(self, x):
self._check_input(x)
amp = x.exp()
spec = (2*x).exp()
amp = x.ptw("exp")
spec = amp*amp
# FIXME This normalizes also the zeromode which is supposed to be left
# untouched by this operator
return self._specsum(self._mode_multiplicity(spec))**(-0.5)*amp
......@@ -332,17 +332,17 @@ class _Amplitude(Operator):
sig_fluc = vol1 @ ps_expander @ fluctuations
xi = ducktape(dom, None, key)
sigma = sig_flex*(Adder(shift) @ sig_asp).sqrt()
sigma = sig_flex*(Adder(shift) @ sig_asp).ptw("sqrt")
smooth = _SlopeRemover(target, space) @ twolog @ (sigma*xi)
op = _Normalization(target, space) @ (slope + smooth)
if N_copies > 0:
op = Distributor @ op
sig_fluc = Distributor @ sig_fluc
op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.one_over())*op)
op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.ptw("reciprocal"))*op)
self._fluc = (_Distributor(dofdex, fluctuations.target,
distributed_tgt[0]) @ fluctuations)
else:
op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.one_over())*op)
op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.ptw("reciprocal"))*op)
self._fluc = fluctuations
self.apply = op.apply
......@@ -527,7 +527,7 @@ class CorrelatedFieldMaker:
for _ in range(prior_info):
sc.add(op(from_random('normal', op.domain)))
mean = sc.mean.val
stddev = sc.var.sqrt().val
stddev = sc.var.ptw("sqrt").val
for m, s in zip(mean.flatten(), stddev.flatten()):
logger.info('{}: {:.02E} ± {:.02E}'.format(kk, m, s))
......@@ -539,7 +539,7 @@ class CorrelatedFieldMaker:
from ..sugar import from_random
scm = 1.
for a in self._a:
op = a.fluctuation_amplitude*self._azm.one_over()
op = a.fluctuation_amplitude*self._azm.ptw("reciprocal")
res = np.array([op(from_random('normal', op.domain)).val
for _ in range(nsamples)])
scm *= res**2 + 1.
......@@ -573,9 +573,9 @@ class CorrelatedFieldMaker:
return self.average_fluctuation(0)
q = 1.
for a in self._a:
fl = a.fluctuation_amplitude*self._azm.one_over()
fl = a.fluctuation_amplitude*self._azm.ptw("reciprocal")
q = q*(Adder(full(fl.target, 1.)) @ fl**2)
return (Adder(full(q.target, -1.)) @ q).sqrt()*self._azm
return (Adder(full(q.target, -1.)) @ q).ptw("sqrt")*self._azm
def slice_fluctuation(self, space):
"""Returns operator which acts on prior or posterior samples"""
......@@ -587,12 +587,12 @@ class CorrelatedFieldMaker:
return self.average_fluctuation(0)
q = 1.
for j in range(len(self._a)):
fl = self._a[j].fluctuation_amplitude*self._azm.one_over()
fl = self._a[j].fluctuation_amplitude*self._azm.ptw("reciprocal")
if j == space:
q = q*fl**2
else:
q = q*(Adder(full(fl.target, 1.)) @ fl**2)
return q.sqrt()*self._azm
return q.ptw("sqrt")*self._azm
def average_fluctuation(self, space):
"""Returns operator which acts on prior or posterior samples"""
......
......@@ -97,9 +97,9 @@ def _make_dynamic_operator(target, harmonic_padding, sm_s0, sm_x0, cone, keys, c
m = CentralPadd.adjoint(FFTB(Sm(m)))
ops['smoothed_dynamics'] = m
m = -m.log()
m = -m.ptw("log")
if not minimum_phase:
m = m.exp()
m = m.ptw("exp")
if causal or minimum_phase:
m = Real.adjoint(FFT.inverse(Realizer(FFT.target).adjoint(m)))
kernel = makeOp(
......@@ -114,19 +114,19 @@ def _make_dynamic_operator(target, harmonic_padding, sm_s0, sm_x0, cone, keys, c
c = FieldAdapter(UnstructuredDomain(len(sigc)), keys[1])
c = makeOp(Field(c.target, np.array(sigc)))(c)
lightspeed = ScalingOperator(c.target, -0.5)(c).exp()
lightspeed = ScalingOperator(c.target, -0.5)(c).ptw("exp")
scaling = np.array(m.target[0].distances[1:])/m.target[0].distances[0]
scaling = DiagonalOperator(Field(c.target, scaling))
ops['lightspeed'] = scaling(lightspeed)
c = LightConeOperator(c.target, m.target, quant) @ c.exp()
c = LightConeOperator(c.target, m.target, quant) @ c.ptw("exp")
ops['light_cone'] = c
m = c*m
if causal or minimum_phase:
m = FFT(Real(m))
if minimum_phase:
m = m.exp()
m = m.ptw("exp")
return m, ops
......
......@@ -120,7 +120,7 @@ def InverseGammaOperator(domain, alpha, q, delta=1e-2):
Distance between sampling points for linear interpolation.
"""
op = _InterpolationOperator(domain, lambda x: invgamma.ppf(norm._cdf(x), float(alpha)),
-8.2, 8.2, delta, lambda x: x.log(), lambda x: x.exp())
-8.2, 8.2, delta, lambda x: x.ptw("log"), lambda x: x.ptw("exp"))
if np.isscalar(q):
return op.scale(q)
return makeOp(q) @ op
......
......@@ -166,7 +166,7 @@ class Linearization(object):
return self.__mul__(1./other)
def __rtruediv__(self, other):
return self.one_over().__mul__(other)
return self.reciprocal().__mul__(other)
def __pow__(self, power):
if not np.isscalar(power):
......@@ -282,9 +282,10 @@ class Linearization(object):
self._val.integrate(spaces),
ContractionOperator(self._jac.target, spaces, 1)(self._jac))
def exp(self):
tmp = self._val.exp()
return self.new(tmp, makeOp(tmp)(self._jac))
def ptw(self, op):
from .pointwise import ptw_dict
t1, t2 = self._val.ptw(op, True)
return self.new(t1, makeOp(t2)(self._jac))
def clip(self, min=None, max=None):
tmp = self._val.clip(min, max)
......@@ -298,90 +299,6 @@ class Linearization(object):
tmp2 = makeOp(1. - (tmp == min) - (tmp == max))
return self.new(tmp, tmp2(self._jac))
def sqrt(self):
tmp = self._val.sqrt()
return self.new(tmp, makeOp(0.5/tmp)(self._jac))
def sin(self):
tmp = self._val.sin()
tmp2 = self._val.cos()
return self.new(tmp, makeOp(tmp2)(self._jac))
def cos(self):
tmp = self._val.cos()
tmp2 = - self._val.sin()
return self.new(tmp, makeOp(tmp2)(self._jac))
def tan(self):
tmp = self._val.tan()
tmp2 = 1./(self._val.cos()**2)
return self.new(tmp, makeOp(tmp2)(self._jac))
def sinc(self):
tmp = self._val.sinc()
tmp2 = ((np.pi*self._val).cos()-tmp)/self._val
ind = self._val.val == 0
loc = tmp2.val_rw()
loc[ind] = 0
tmp2 = Field(tmp.domain, loc)
return self.new(tmp, makeOp(tmp2)(self._jac))
def log(self):
tmp = self._val.log()
return self.new(tmp, makeOp(1./self._val)(self._jac))
def log10(self):
tmp = self._val.log10()
tmp2 = 1. / (self._val * np.log(10))
return self.new(tmp, makeOp(tmp2)(self._jac))
def log1p(self):
tmp = self._val.log1p()
tmp2 = 1. / (1. + self._val)
return self.new(tmp, makeOp(tmp2)(self.jac))
def expm1(self):
tmp = self._val.expm1()
tmp2 = self._val.exp()
return self.new(tmp, makeOp(tmp2)(self.jac))
def sinh(self):
tmp = self._val.sinh()
tmp2 = self._val.cosh()
return self.new(tmp, makeOp(tmp2)(self._jac))
def cosh(self):
tmp = self._val.cosh()
tmp2 = self._val.sinh()
return self.new(tmp, makeOp(tmp2)(self._jac))
def tanh(self):
tmp = self._val.tanh()
return self.new(tmp, makeOp(1.-tmp**2)(self._jac))
def sigmoid(self):
tmp = self._val.tanh()
tmp2 = 0.5*(1.+tmp)
return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))
def absolute(self):
if utilities.iscomplextype(self._val.dtype):
raise TypeError("Argument must not be complex")
tmp = self._val.absolute()
tmp2 = self._val.sign()
ind = self._val.val == 0
loc = tmp2.val_rw().astype(float)
loc[ind] = np.nan
tmp2 = Field(tmp.domain, loc)
return self.new(tmp, makeOp(tmp2)(self._jac))
def one_over(self):
tmp = 1./self._val
tmp2 = - tmp/self._val
return self.new(tmp, makeOp(tmp2)(self._jac))
def add_metric(self, metric):
return self.new(self._val, self._jac, metric)
......
......@@ -310,8 +310,12 @@ class MultiField(object):
res[key] = -val if neg else val
return MultiField.from_dict(res)
def one_over(self):
return 1/self
def ptw(self, op, with_deriv=False):
tmp = tuple(val.ptw(op, with_deriv) for val in self.values())
if with_deriv:
return (MultiField(self.domain, tuple(v[0] for v in tmp)),
MultiField(self.domain, tuple(v[1] for v in tmp)))
return MultiField(self.domain, tmp)
def _binary_op(self, other, op):
f = getattr(Field, op)
......@@ -347,14 +351,3 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"In-place operations are deliberately not supported")
return func2
setattr(MultiField, op, func(op))
for f in ["sqrt", "exp", "log", "sin", "cos", "tan", "sinh", "cosh", "tanh",
"absolute", "sinc", "sign", "log10", "log1p", "expm1"]:
def func(f):
def func2(self):
fu = getattr(Field, f)
return MultiField(self.domain,
tuple(fu(val) for val in self.values()))
return func2
setattr(MultiField, f, func(f))
......@@ -127,7 +127,7 @@ class VariableCovarianceGaussianEnergy(EnergyOperator):
def apply(self, x):
self._check_input(x)
res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov]).real - x[self._icov].log().sum())
res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov]).real - x[self._icov].ptw("log").sum())
if not isinstance(x, Linearization) or not x.want_metric:
return res
mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)}
......@@ -229,7 +229,7 @@ class PoissonianEnergy(EnergyOperator):
def apply(self, x):
self._check_input(x)
res = x.sum() - x.log().vdot(self._d)
res = x.sum() - x.ptw("log").vdot(self._d)
if not isinstance(x, Linearization) or not x.want_metric:
return res
return res.add_metric(makeOp(1./x.val))
......@@ -269,7 +269,7 @@ class InverseGammaLikelihood(EnergyOperator):
def apply(self, x):
self._check_input(x)
res = x.log().vdot(self._alphap1) + x.one_over().vdot(self._beta)
res = x.ptw("log").vdot(self._alphap1) + x.ptw("reciprocal").vdot(self._beta)
if not isinstance(x, Linearization) or not x.want_metric:
return res
return res.add_metric(makeOp(self._alphap1/(x.val**2)))
......@@ -298,7 +298,7 @@ class StudentTEnergy(EnergyOperator):
def apply(self, x):
self._check_input(x)
res = ((self._theta+1)/2)*(x**2/self._theta).log1p().sum()
res = ((self._theta+1)/2)*(x**2/self._theta).ptw("log1p").sum()
if not isinstance(x, Linearization) or not x.want_metric:
return res
met = ScalingOperator(self.domain, (self._theta+1) / (self._theta+3))
......@@ -332,7 +332,7 @@ class BernoulliEnergy(EnergyOperator):
def apply(self, x):
self._check_input(x)
res = -x.log().vdot(self._d) + (1.-x).log().vdot(self._d-1.)
res = -x.ptw("log").vdot(self._d) + (1.-x).ptw("log").vdot(self._d-1.)
if not isinstance(x, Linearization) or not x.want_metric:
return res
return res.add_metric(makeOp(1./(x.val*(1. - x.val))))
......
......@@ -222,15 +222,8 @@ class Operator(metaclass=NiftyMeta):
def _simplify_for_constant_input_nontrivial(self, c_inp):
return None, self
for f in ["sqrt", "exp", "log", "sin", "cos", "tan", "sinh", "cosh", "tanh",
"sinc", "sigmoid", "absolute", "one_over", "log10", "log1p", "expm1"]:
def func(f):
def func2(self):
fa = _FunctionApplier(self.target, f)
return _OpChain.make((fa, self))
return func2
setattr(Operator, f, func(f))
def ptw(self, op):
return _OpChain.make((_FunctionApplier(self.target, op), self))
class _ConstCollector(object):
......@@ -301,7 +294,7 @@ class _FunctionApplier(Operator):
def apply(self, x):
self._check_input(x)
return getattr(x, self._funcname)()
return x.ptw(self._funcname)
class _Clipper(Operator):
......
......@@ -37,10 +37,7 @@ from .plot import Plot
__all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random',
'full', 'makeField',
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'sigmoid',
'sin', 'cos', 'tan', 'sinh', 'cosh', 'log10',
'absolute', 'one_over', 'clip', 'sinc', "log1p", "expm1",
'conjugate', 'get_signal_variance', 'makeOp', 'domain_union',
'makeDomain', 'get_signal_variance', 'makeOp', 'domain_union',
'get_default_codomain', 'single_plot', 'exec_time',
'calculate_position']
......@@ -366,26 +363,6 @@ def domain_union(domains):
return MultiDomain.union(domains)
# Arithmetic functions working on Fields
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "log10", "tanh", "sigmoid",
"conjugate", 'sin', 'cos', 'tan', 'sinh', 'cosh',
'absolute', 'one_over', 'sinc', 'log1p', 'expm1']:
def func(f):
def func2(x):
from .linearization import Linearization
from .operators.operator import Operator
if isinstance(x, (Field, MultiField, Linearization, Operator)):
return getattr(x, f)()
else:
return getattr(np, f)(x)
return func2
setattr(_current_module, f, func(f))
def clip(a, a_min=None, a_max=None):
return a.clip(a_min, a_max)
......
......@@ -77,7 +77,7 @@ def test_studentt(field):
def test_hamiltonian_and_KL(field):
field = field.exp()
field = field.ptw("exp")
space = field.domain
lh = ift.GaussianEnergy(domain=space)
hamiltonian = ift.StandardHamiltonian(lh)
......@@ -91,7 +91,7 @@ def test_hamiltonian_and_KL(field):
def test_variablecovariancegaussian(field):
if isinstance(field.domain, ift.MultiDomain):
return
dc = {'a': field, 'b': field.exp()}
dc = {'a': field, 'b': field.ptw("exp")}
mf = ift.MultiField.from_dict(dc)
energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b')
ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6)
......@@ -101,7 +101,7 @@ def test_variablecovariancegaussian(field):
def test_inverse_gamma(field):
if isinstance(field.domain, ift.MultiDomain):
return
field = field.exp()
field = field.ptw("exp")
space = field.domain
d = ift.random.current_rng().normal(10, size=space.shape)**2
d = ift.Field(space, d)
......@@ -112,7 +112,7 @@ def test_inverse_gamma(field):
def testPoissonian(field):
if isinstance(field.domain, ift.MultiDomain):
return
field = field.exp()
field = field.ptw("exp")
space = field.domain
d = ift.random.current_rng().poisson(120, size=space.shape)
d = ift.Field(space, d)
......@@ -123,7 +123,7 @@ def testPoissonian(field):
def test_bernoulli(field):
if isinstance(field.domain, ift.MultiDomain):
return
field = field.sigmoid()
field = field.ptw("sigmoid")
space = field.domain
d = ift.random.current_rng().binomial(1, 0.1, size=space.shape)
d = ift.Field(space, d)
......
......@@ -198,7 +198,7 @@ def test_trivialities():
assert_equal(f1.val, f1.real.val)
assert_equal(f1.val, (+f1).val)
f1 = ift.Field.full(s1, 27. + 3j)
assert_equal(f1.one_over().val, (1./f1).val)
assert_equal(f1.ptw("reciprocal").val, (1./f1).val)
assert_equal(f1.real.val, 27.)
assert_equal(f1.imag.val, 3.)
assert_equal(f1.s_sum(), f1.sum(0).val)
......@@ -336,7 +336,7 @@ def test_emptydomain():
def test_funcs(num, dom, func):
num = 5
f = ift.Field.full(dom, num)
res = getattr(f, func)()
res = f.ptw(func)
res2 = getattr(np, func)(num)
assert_allclose(res.val, res2)
......
......@@ -51,7 +51,7 @@ def test_gaussian_energy(space, nonlinearity, noise, seed):
return 1/(1 + k**2)**dim
pspec = ift.PS_field(pspace, pspec)
A = Dist(ift.sqrt(pspec))
A = Dist(pspec.ptw("sqrt"))
N = ift.ScalingOperator(space, noise)
n = N.draw_sample()
R = ift.ScalingOperator(space, 10.)
......@@ -61,7 +61,7 @@ def test_gaussian_energy(space, nonlinearity, noise, seed):
return R @ ht @ ift.makeOp(A)
else:
tmp = ht @ ift.makeOp(A)
nonlin = getattr(tmp, nonlinearity)()
nonlin = tmp.ptw(nonlinearity)
return R @ nonlin
d = d_model()(xi0) + n
......
......@@ -43,19 +43,19 @@ def test_special_gradients():
jt(var.clip(-1, 0), np.zeros_like(s))
assert_allclose(
_lin2grad(ift.Linearization.make_var(0*f).sinc()), np.zeros(s.shape))
assert_(np.isnan(_lin2grad(ift.Linearization.make_var(0*f).absolute())))
_lin2grad(ift.Linearization.make_var(0*f).ptw("sinc")), np.zeros(s.shape))
assert_(np.isnan(_lin2grad(ift.Linearization.make_var(0*f).ptw("abs"))))
assert_allclose(
_lin2grad(ift.Linearization.make_var(0*f + 10).absolute()),
_lin2grad(ift.Linearization.make_var(0*f + 10).ptw("abs")),
np.ones(s.shape))
assert_allclose(
_lin2grad(ift.Linearization.make_var(0*f - 10).absolute()),
_lin2grad(ift.Linearization.make_var(0*f - 10).ptw("abs")),
-np.ones(s.shape))
@pmp('f', [
'log', 'exp', 'sqrt', 'sin', 'cos', 'tan', 'sinc', 'sinh', 'cosh', 'tanh',
'absolute', 'one_over', 'sigmoid', 'log10', 'log1p', "expm1"
'absolute', 'reciprocal', 'sigmoid', 'log10', 'log1p', "expm1"
])
def test_actual_gradients(f):
dom = ift.UnstructuredDomain((1,))
......@@ -63,8 +63,8 @@ def test_actual_gradients(f):
eps = 1e-8
var0 = ift.Linearization.make_var(fld)
var1 = ift.Linearization.make_var(fld + eps)
f0 = getattr(var0, f)().val.val
f1 = getattr(var1, f)().val.val
f0 = var0.ptw(f).val.val
f1 = var1.ptw(f).val.val
df0 = (f1 - f0)/eps
df1 = _lin2grad(getattr(var0, f)())
df1 = _lin2grad(var0.ptw(f))
assert_allclose(df0, df1, rtol=100*eps)
......@@ -33,7 +33,7 @@ def test_vdot():
def test_func():
f1 = ift.from_random("normal", domain=dom, dtype=np.complex128)
assert_allclose(
ift.log(ift.exp((f1)))["d1"].val, f1["d1"].val)
f1.ptw("exp").ptw("log")["d1"].val, f1["d1"].val)
def test_multifield_field_consistency():
......