Commit 8dca6e3b authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'pointwise' into 'NIFTy_6'

Rework of pointwise operations

See merge request !440
parents 9974ac4e 55ec681a
Pipeline #72716 passed with stages
in 19 minutes and 56 seconds
......@@ -142,8 +142,7 @@ class RGSpace(StructuredDomain):
@staticmethod
def _kernel(x, sigma):
from ..sugar import exp
return exp(x*x * (-2.*np.pi*np.pi*sigma*sigma))
return (x*x * (-2.*np.pi*np.pi*sigma*sigma)).ptw("exp")
def get_fft_smoothing_kernel_function(self, sigma):
if (not self.harmonic):
......
......@@ -20,9 +20,10 @@ import numpy as np
from . import utilities
from .domain_tuple import DomainTuple
from .operators.operator import Operator
class Field(object):
class Field(Operator):
"""The discrete representation of a continuous field over multiple spaces.
Stores data arrays and carries all the needed meta-information (i.e. the
......@@ -634,10 +635,9 @@ class Field(object):
Field
The result of the operation.
"""
from .sugar import sqrt
if self.scalar_weight(spaces) is not None:
return self._contraction_helper('std', spaces)
return sqrt(self.var(spaces))
return self.var(spaces).ptw("sqrt")
def s_std(self):
"""Determines the standard deviation of the Field.
......@@ -677,17 +677,6 @@ class Field(object):
def flexible_addsub(self, other, neg):
return self-other if neg else self+other
def sigmoid(self):
return 0.5*(1.+self.tanh())
def clip(self, min=None, max=None):
min = min.val if isinstance(min, Field) else min
max = max.val if isinstance(max, Field) else max
return Field(self._domain, np.clip(self._val, min, max))
def one_over(self):
return 1/self
def _binary_op(self, other, op):
# if other is a field, make sure that the domains match
f = getattr(self._val, op)
......@@ -699,6 +688,26 @@ class Field(object):
return Field(self._domain, f(other))
return NotImplemented
def _prep_args(self, args, kwargs):
for arg in args + tuple(kwargs.values()):
if not (arg is None or np.isscalar(arg) or arg.jac is None):
raise TypeError("bad argument")
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val
for key, val in kwargs.items()}
return argstmp, kwargstmp
def ptw(self, op, *args, **kwargs):
from .pointwise import ptw_dict
argstmp, kwargstmp = self._prep_args(args, kwargs)
return Field(self._domain, ptw_dict[op][0](self._val, *argstmp, **kwargstmp))
def ptw_with_deriv(self, op, *args, **kwargs):
from .pointwise import ptw_dict
argstmp, kwargstmp = self._prep_args(args, kwargs)
tmp = ptw_dict[op][1](self._val, *argstmp, **kwargstmp)
return (Field(self._domain, tmp[0]), Field(self._domain, tmp[1]))
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
......@@ -721,11 +730,3 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"In-place operations are deliberately not supported")
return func2
setattr(Field, op, func(op))
for f in ["sqrt", "exp", "log", "sin", "cos", "tan", "sinh", "cosh", "tanh",
"absolute", "sinc", "sign", "log10", "log1p", "expm1"]:
def func(f):
def func2(self):
return Field(self._domain, getattr(np, f)(self.val))
return func2
setattr(Field, f, func(f))
......@@ -126,7 +126,7 @@ class _LognormalMomentMatching(Operator):
logmean, logsig = _lognormal_moments(mean, sig, N_copies)
self._mean = mean
self._sig = sig
op = _normal(logmean, logsig, key, N_copies).exp()
op = _normal(logmean, logsig, key, N_copies).ptw("exp")
self._domain, self._target = op.domain, op.target
self.apply = op.apply
......@@ -224,8 +224,8 @@ class _Normalization(Operator):
def apply(self, x):
self._check_input(x)
amp = x.exp()
spec = (2*x).exp()
amp = x.ptw("exp")
spec = amp**2
# FIXME This normalizes also the zeromode which is supposed to be left
# untouched by this operator
return self._specsum(self._mode_multiplicity(spec))**(-0.5)*amp
......@@ -332,17 +332,17 @@ class _Amplitude(Operator):
sig_fluc = vol1 @ ps_expander @ fluctuations
xi = ducktape(dom, None, key)
sigma = sig_flex*(Adder(shift) @ sig_asp).sqrt()
sigma = sig_flex*(Adder(shift) @ sig_asp).ptw("sqrt")
smooth = _SlopeRemover(target, space) @ twolog @ (sigma*xi)
op = _Normalization(target, space) @ (slope + smooth)
if N_copies > 0:
op = Distributor @ op
sig_fluc = Distributor @ sig_fluc
op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.one_over())*op)
op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.ptw("reciprocal"))*op)
self._fluc = (_Distributor(dofdex, fluctuations.target,
distributed_tgt[0]) @ fluctuations)
else:
op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.one_over())*op)
op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.ptw("reciprocal"))*op)
self._fluc = fluctuations
self.apply = op.apply
......@@ -527,7 +527,7 @@ class CorrelatedFieldMaker:
for _ in range(prior_info):
sc.add(op(from_random('normal', op.domain)))
mean = sc.mean.val
stddev = sc.var.sqrt().val
stddev = sc.var.ptw("sqrt").val
for m, s in zip(mean.flatten(), stddev.flatten()):
logger.info('{}: {:.02E} ± {:.02E}'.format(kk, m, s))
......@@ -539,7 +539,7 @@ class CorrelatedFieldMaker:
from ..sugar import from_random
scm = 1.
for a in self._a:
op = a.fluctuation_amplitude*self._azm.one_over()
op = a.fluctuation_amplitude*self._azm.ptw("reciprocal")
res = np.array([op(from_random('normal', op.domain)).val
for _ in range(nsamples)])
scm *= res**2 + 1.
......@@ -573,9 +573,9 @@ class CorrelatedFieldMaker:
return self.average_fluctuation(0)
q = 1.
for a in self._a:
fl = a.fluctuation_amplitude*self._azm.one_over()
fl = a.fluctuation_amplitude*self._azm.ptw("reciprocal")
q = q*(Adder(full(fl.target, 1.)) @ fl**2)
return (Adder(full(q.target, -1.)) @ q).sqrt()*self._azm
return (Adder(full(q.target, -1.)) @ q).ptw("sqrt")*self._azm
def slice_fluctuation(self, space):
"""Returns operator which acts on prior or posterior samples"""
......@@ -587,12 +587,12 @@ class CorrelatedFieldMaker:
return self.average_fluctuation(0)
q = 1.
for j in range(len(self._a)):
fl = self._a[j].fluctuation_amplitude*self._azm.one_over()
fl = self._a[j].fluctuation_amplitude*self._azm.ptw("reciprocal")
if j == space:
q = q*fl**2
else:
q = q*(Adder(full(fl.target, 1.)) @ fl**2)
return q.sqrt()*self._azm
return q.ptw("sqrt")*self._azm
def average_fluctuation(self, space):
"""Returns operator which acts on prior or posterior samples"""
......
......@@ -97,9 +97,9 @@ def _make_dynamic_operator(target, harmonic_padding, sm_s0, sm_x0, cone, keys, c
m = CentralPadd.adjoint(FFTB(Sm(m)))
ops['smoothed_dynamics'] = m
m = -m.log()
m = -m.ptw("log")
if not minimum_phase:
m = m.exp()
m = m.ptw("exp")
if causal or minimum_phase:
m = Real.adjoint(FFT.inverse(Realizer(FFT.target).adjoint(m)))
kernel = makeOp(
......@@ -114,19 +114,19 @@ def _make_dynamic_operator(target, harmonic_padding, sm_s0, sm_x0, cone, keys, c
c = FieldAdapter(UnstructuredDomain(len(sigc)), keys[1])
c = makeOp(Field(c.target, np.array(sigc)))(c)
lightspeed = ScalingOperator(c.target, -0.5)(c).exp()
lightspeed = ScalingOperator(c.target, -0.5)(c).ptw("exp")
scaling = np.array(m.target[0].distances[1:])/m.target[0].distances[0]
scaling = DiagonalOperator(Field(c.target, scaling))
ops['lightspeed'] = scaling(lightspeed)
c = LightConeOperator(c.target, m.target, quant) @ c.exp()
c = LightConeOperator(c.target, m.target, quant) @ c.ptw("exp")
ops['light_cone'] = c
m = c*m
if causal or minimum_phase:
m = FFT(Real(m))
if minimum_phase:
m = m.exp()
m = m.ptw("exp")
return m, ops
......
......@@ -19,7 +19,6 @@ import numpy as np
from ..domain_tuple import DomainTuple
from ..field import Field
from ..linearization import Linearization
from ..operators.linear_operator import LinearOperator
from ..operators.operator import Operator
......@@ -132,7 +131,7 @@ class LightConeOperator(Operator):
self._sigx = sigx
def apply(self, x):
lin = isinstance(x, Linearization)
lin = x.jac is not None
a, derivs = _cone_arrays(x.val.val if lin else x.val, self.target, self._sigx, lin)
res = Field(self.target, a)
if not lin:
......
......@@ -22,7 +22,6 @@ from scipy.interpolate import CubicSpline
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..linearization import Linearization
from ..operators.operator import Operator
from ..sugar import makeOp
from .. import random
......@@ -79,7 +78,7 @@ class _InterpolationOperator(Operator):
def apply(self, x):
self._check_input(x)
lin = isinstance(x, Linearization)
lin = x.jac is not None
xval = x.val.val if lin else x.val
res = self._interpolator(xval)
res = Field(self._domain, res)
......@@ -120,7 +119,7 @@ def InverseGammaOperator(domain, alpha, q, delta=1e-2):
Distance between sampling points for linear interpolation.
"""
op = _InterpolationOperator(domain, lambda x: invgamma.ppf(norm._cdf(x), float(alpha)),
-8.2, 8.2, delta, lambda x: x.log(), lambda x: x.exp())
-8.2, 8.2, delta, lambda x: x.ptw("log"), lambda x: x.ptw("exp"))
if np.isscalar(q):
return op.scale(q)
return makeOp(q) @ op
......@@ -148,7 +147,7 @@ class UniformOperator(Operator):
def apply(self, x):
self._check_input(x)
lin = isinstance(x, Linearization)
lin = x.jac is not None
xval = x.val.val if lin else x.val
res = Field(self._target, self._scale*norm._cdf(xval) + self._loc)
if not lin:
......
......@@ -17,13 +17,12 @@
import numpy as np
from .field import Field
from .multi_field import MultiField
from .sugar import makeOp
from . import utilities
from .operators.operator import Operator
class Linearization(object):
class Linearization(Operator):
"""Let `A` be an operator and `x` a field. `Linearization` stores the value
of the operator application (i.e. `A(x)`), the local Jacobian
(i.e. `dA(x)/dx`) and, optionally, the local metric.
......@@ -64,7 +63,7 @@ class Linearization(object):
return Linearization(val, jac, metric, self._want_metric)
def trivial_jac(self):
return Linearization.make_var(self._val, self._want_metric)
return self.make_var(self._val, self._want_metric)
def prepend_jac(self, jac):
metric = None
......@@ -101,6 +100,7 @@ class Linearization(object):
-----
Only available if target is a scalar
"""
from .field import Field
return self._jac.adjoint_times(Field.scalar(1.))
@property
......@@ -135,18 +135,15 @@ class Linearization(object):
return self.new(self._val.real, self._jac.real)
def _myadd(self, other, neg):
if isinstance(other, Linearization):
met = None
if self._metric is not None and other._metric is not None:
met = self._metric._myadd(other._metric, neg)
return self.new(
self._val.flexible_addsub(other._val, neg),
self._jac._myadd(other._jac, neg), met)
if isinstance(other, (int, float, complex, Field, MultiField)):
if neg:
return self.new(self._val-other, self._jac, self._metric)
else:
return self.new(self._val+other, self._jac, self._metric)
if np.isscalar(other) or other.jac is None:
return self.new(self._val-other if neg else self._val+other,
self._jac, self._metric)
met = None
if self._metric is not None and other._metric is not None:
met = self._metric._myadd(other._metric, neg)
return self.new(
self.val.flexible_addsub(other.val, neg),
self.jac._myadd(other.jac, neg), met)
def __add__(self, other):
return self._myadd(other, False)
......@@ -161,37 +158,35 @@ class Linearization(object):
return (-self).__add__(other)
def __truediv__(self, other):
if isinstance(other, Linearization):
return self.__mul__(other.one_over())
return self.__mul__(1./other)
if np.isscalar(other):
return self.__mul__(1/other)
return self.__mul__(other.ptw("reciprocal"))
def __rtruediv__(self, other):
return self.one_over().__mul__(other)
return self.ptw("reciprocal").__mul__(other)
def __pow__(self, power):
if not np.isscalar(power):
if not (np.isscalar(power) or power.jac is None):
return NotImplemented
return self.new(self._val**power,
makeOp(self._val**(power-1)).scale(power)(self._jac))
return self.ptw("power", power)
def __mul__(self, other):
from .sugar import makeOp
if isinstance(other, Linearization):
if self.target != other.target:
raise ValueError("domain mismatch")
return self.new(
self._val*other._val,
(makeOp(other._val)(self._jac))._myadd(
makeOp(self._val)(other._jac), False))
if np.isscalar(other):
if other == 1:
return self
met = None if self._metric is None else self._metric.scale(other)
return self.new(self._val*other, self._jac.scale(other), met)
if isinstance(other, (Field, MultiField)):
from .sugar import makeOp
if other.jac is None:
if self.target != other.domain:
raise ValueError("domain mismatch")
return self.new(self._val*other, makeOp(other)(self._jac))
if self.target != other.target:
raise ValueError("domain mismatch")
return self.new(
self.val*other.val,
(makeOp(other.val)(self.jac))._myadd(
makeOp(self.val)(other.jac), False))
def __rmul__(self, other):
return self.__mul__(other)
......@@ -209,17 +204,16 @@ class Linearization(object):
Linearization
the outer product of self and other
"""
from .operators.outer_product_operator import OuterProduct
if isinstance(other, Linearization):
return self.new(
OuterProduct(self._val, other.target)(other._val),
OuterProduct(self._jac(self._val), other.target)._myadd(
OuterProduct(self._val, other.target)(other._jac), False))
if np.isscalar(other):
return self.__mul__(other)
if isinstance(other, (Field, MultiField)):
from .operators.outer_product_operator import OuterProduct
if other.jac is None:
return self.new(OuterProduct(self._val, other.domain)(other),
OuterProduct(self._jac(self._val), other.domain))
return self.new(
OuterProduct(self._val, other.target)(other._val),
OuterProduct(self._jac(self._val), other.target)._myadd(
OuterProduct(self._val, other.target)(other._jac), False))
def vdot(self, other):
"""Computes the inner product of this Linearization with a Field or
......@@ -235,7 +229,7 @@ class Linearization(object):
the inner product of self and other
"""
from .operators.simple_linear_operators import VdotOperator
if isinstance(other, (Field, MultiField)):
if other.jac is None:
return self.new(
self._val.vdot(other),
VdotOperator(other)(self._jac))
......@@ -282,105 +276,10 @@ class Linearization(object):
self._val.integrate(spaces),
ContractionOperator(self._jac.target, spaces, 1)(self._jac))
def exp(self):
tmp = self._val.exp()
return self.new(tmp, makeOp(tmp)(self._jac))
def clip(self, min=None, max=None):
tmp = self._val.clip(min, max)
if (min is None) and (max is None):
return self
elif max is None:
tmp2 = makeOp(1. - (tmp == min))
elif min is None:
tmp2 = makeOp(1. - (tmp == max))
else:
tmp2 = makeOp(1. - (tmp == min) - (tmp == max))
return self.new(tmp, tmp2(self._jac))
def sqrt(self):
tmp = self._val.sqrt()
return self.new(tmp, makeOp(0.5/tmp)(self._jac))
def sin(self):
tmp = self._val.sin()
tmp2 = self._val.cos()
return self.new(tmp, makeOp(tmp2)(self._jac))
def cos(self):
tmp = self._val.cos()
tmp2 = - self._val.sin()
return self.new(tmp, makeOp(tmp2)(self._jac))
def tan(self):
tmp = self._val.tan()
tmp2 = 1./(self._val.cos()**2)
return self.new(tmp, makeOp(tmp2)(self._jac))
def sinc(self):
tmp = self._val.sinc()
tmp2 = ((np.pi*self._val).cos()-tmp)/self._val
ind = self._val.val == 0
loc = tmp2.val_rw()
loc[ind] = 0
tmp2 = Field(tmp.domain, loc)
return self.new(tmp, makeOp(tmp2)(self._jac))
def log(self):
tmp = self._val.log()
return self.new(tmp, makeOp(1./self._val)(self._jac))
def log10(self):
tmp = self._val.log10()
tmp2 = 1. / (self._val * np.log(10))
return self.new(tmp, makeOp(tmp2)(self._jac))
def log1p(self):
tmp = self._val.log1p()
tmp2 = 1. / (1. + self._val)
return self.new(tmp, makeOp(tmp2)(self.jac))
def expm1(self):
tmp = self._val.expm1()
tmp2 = self._val.exp()
return self.new(tmp, makeOp(tmp2)(self.jac))
def sinh(self):
tmp = self._val.sinh()
tmp2 = self._val.cosh()
return self.new(tmp, makeOp(tmp2)(self._jac))
def cosh(self):
tmp = self._val.cosh()
tmp2 = self._val.sinh()
return self.new(tmp, makeOp(tmp2)(self._jac))
def tanh(self):
tmp = self._val.tanh()
return self.new(tmp, makeOp(1.-tmp**2)(self._jac))
def sigmoid(self):
tmp = self._val.tanh()
tmp2 = 0.5*(1.+tmp)
return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac))
def absolute(self):
if utilities.iscomplextype(self._val.dtype):
raise TypeError("Argument must not be complex")
tmp = self._val.absolute()
tmp2 = self._val.sign()
ind = self._val.val == 0
loc = tmp2.val_rw().astype(float)
loc[ind] = np.nan
tmp2 = Field(tmp.domain, loc)
return self.new(tmp, makeOp(tmp2)(self._jac))
def one_over(self):
tmp = 1./self._val
tmp2 = - tmp/self._val
return self.new(tmp, makeOp(tmp2)(self._jac))
def ptw(self, op, *args, **kwargs):
from .pointwise import ptw_dict
t1, t2 = self._val.ptw_with_deriv(op, *args, **kwargs)
return self.new(t1, makeOp(t2)(self._jac))
def add_metric(self, metric):
return self.new(self._val, self._jac, metric)
......
......@@ -21,9 +21,10 @@ from . import utilities
from .field import Field
from .multi_domain import MultiDomain
from .domain_tuple import DomainTuple
from .operators.operator import Operator
class MultiField(object):
class MultiField(Operator):
def __init__(self, domain, val):
"""The discrete representation of a continuous field over a sum space.
......@@ -199,13 +200,8 @@ class MultiField(object):
def conjugate(self):
return self._transform(lambda x: x.conjugate())
def clip(self, min=None, max=None):
ncomp = len(self._val)
lmin = min._val if isinstance(min, MultiField) else (min,)*ncomp
lmax = max._val if isinstance(max, MultiField) else (max,)*ncomp
return MultiField(
self._domain,
tuple(self._val[i].clip(lmin[i], lmax[i]) for i in range(ncomp)))
def clip(self, a_min=None, a_max=None):
return self.ptw("clip", a_min, a_max)
def s_all(self):
for v in self._val:
......@@ -310,8 +306,30 @@ class MultiField(object):
res[key] = -val if neg else val
return MultiField.from_dict(res)
def one_over(self):
return 1/self
def _prep_args(self, args, kwargs, i):
for arg in args + tuple(kwargs.values()):
if not (arg is None or np.isscalar(arg) or arg.jac is None):
raise TypeError("bad argument")
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val[i]
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val[i]
for key, val in kwargs.items()}
return argstmp, kwargstmp
def ptw(self, op, *args, **kwargs):
tmp = []
for i in range(len(self._val)):
argstmp, kwargstmp = self._prep_args(args, kwargs, i)
tmp.append(self._val[i].ptw(op, *argstmp, **kwargstmp))
return MultiField(self.domain, tuple(tmp))
def ptw_with_deriv(self, op, *args, **kwargs):
tmp = []
for i in range(len(self._val)):
argstmp, kwargstmp = self._prep_args(args, kwargs, i)
tmp.append(self._val[i].ptw_with_deriv(op, *argstmp, **kwargstmp))
return (MultiField(self.domain, tuple(v[0] for v in tmp)),
MultiField(self.domain, tuple(v[1] for v in tmp)))
def _binary_op(self, other, op):
f = getattr(Field, op)
......@@ -347,14 +365,3 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"In-place operations are deliberately not supported")
return func2
setattr(MultiField, op, func(op))
for f in ["sqrt", "exp", "log", "sin", "cos", "tan", "sinh", "cosh", "tanh",
"absolute", "sinc", "sign", "log10", "log1p", "expm1"]:
def func(f):
def func2(self):
fu = getattr(Field, f)
return MultiField(self.domain,
tuple(fu(val) for val in self.values()))
return func2
setattr(MultiField, f, func(f))
......@@ -20,7 +20,6 @@ import numpy as np
from .. import utilities
from ..domain_tuple import DomainTuple
from ..field import Field
from ..linearization import Linearization
from ..multi_domain import MultiDomain
from ..multi_field import MultiField
from ..sugar import makeDomain, makeOp
......@@ -59,9 +58,8 @@ class Squared2NormOperator(EnergyOperator):
def apply(self, x):
self._check_input(x)