Commit 8dca6e3b authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'pointwise' into 'NIFTy_6'

Rework of pointwise operations

See merge request !440
parents 9974ac4e 55ec681a
Pipeline #72716 passed with stages
in 19 minutes and 56 seconds
...@@ -142,8 +142,7 @@ class RGSpace(StructuredDomain): ...@@ -142,8 +142,7 @@ class RGSpace(StructuredDomain):
@staticmethod @staticmethod
def _kernel(x, sigma): def _kernel(x, sigma):
from ..sugar import exp return (x*x * (-2.*np.pi*np.pi*sigma*sigma)).ptw("exp")
return exp(x*x * (-2.*np.pi*np.pi*sigma*sigma))
def get_fft_smoothing_kernel_function(self, sigma): def get_fft_smoothing_kernel_function(self, sigma):
if (not self.harmonic): if (not self.harmonic):
......
...@@ -20,9 +20,10 @@ import numpy as np ...@@ -20,9 +20,10 @@ import numpy as np
from . import utilities from . import utilities
from .domain_tuple import DomainTuple from .domain_tuple import DomainTuple
from .operators.operator import Operator
class Field(object): class Field(Operator):
"""The discrete representation of a continuous field over multiple spaces. """The discrete representation of a continuous field over multiple spaces.
Stores data arrays and carries all the needed meta-information (i.e. the Stores data arrays and carries all the needed meta-information (i.e. the
...@@ -634,10 +635,9 @@ class Field(object): ...@@ -634,10 +635,9 @@ class Field(object):
Field Field
The result of the operation. The result of the operation.
""" """
from .sugar import sqrt
if self.scalar_weight(spaces) is not None: if self.scalar_weight(spaces) is not None:
return self._contraction_helper('std', spaces) return self._contraction_helper('std', spaces)
return sqrt(self.var(spaces)) return self.var(spaces).ptw("sqrt")
def s_std(self): def s_std(self):
"""Determines the standard deviation of the Field. """Determines the standard deviation of the Field.
...@@ -677,17 +677,6 @@ class Field(object): ...@@ -677,17 +677,6 @@ class Field(object):
def flexible_addsub(self, other, neg): def flexible_addsub(self, other, neg):
return self-other if neg else self+other return self-other if neg else self+other
def sigmoid(self):
return 0.5*(1.+self.tanh())
def clip(self, min=None, max=None):
min = min.val if isinstance(min, Field) else min
max = max.val if isinstance(max, Field) else max
return Field(self._domain, np.clip(self._val, min, max))
def one_over(self):
return 1/self
def _binary_op(self, other, op): def _binary_op(self, other, op):
# if other is a field, make sure that the domains match # if other is a field, make sure that the domains match
f = getattr(self._val, op) f = getattr(self._val, op)
...@@ -699,6 +688,26 @@ class Field(object): ...@@ -699,6 +688,26 @@ class Field(object):
return Field(self._domain, f(other)) return Field(self._domain, f(other))
return NotImplemented return NotImplemented
def _prep_args(self, args, kwargs):
for arg in args + tuple(kwargs.values()):
if not (arg is None or np.isscalar(arg) or arg.jac is None):
raise TypeError("bad argument")
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val
for key, val in kwargs.items()}
return argstmp, kwargstmp
def ptw(self, op, *args, **kwargs):
from .pointwise import ptw_dict
argstmp, kwargstmp = self._prep_args(args, kwargs)
return Field(self._domain, ptw_dict[op][0](self._val, *argstmp, **kwargstmp))
def ptw_with_deriv(self, op, *args, **kwargs):
from .pointwise import ptw_dict
argstmp, kwargstmp = self._prep_args(args, kwargs)
tmp = ptw_dict[op][1](self._val, *argstmp, **kwargstmp)
return (Field(self._domain, tmp[0]), Field(self._domain, tmp[1]))
for op in ["__add__", "__radd__", for op in ["__add__", "__radd__",
"__sub__", "__rsub__", "__sub__", "__rsub__",
...@@ -721,11 +730,3 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", ...@@ -721,11 +730,3 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
"In-place operations are deliberately not supported") "In-place operations are deliberately not supported")
return func2 return func2
setattr(Field, op, func(op)) setattr(Field, op, func(op))
for f in ["sqrt", "exp", "log", "sin", "cos", "tan", "sinh", "cosh", "tanh",
"absolute", "sinc", "sign", "log10", "log1p", "expm1"]:
def func(f):
def func2(self):
return Field(self._domain, getattr(np, f)(self.val))
return func2
setattr(Field, f, func(f))
...@@ -126,7 +126,7 @@ class _LognormalMomentMatching(Operator): ...@@ -126,7 +126,7 @@ class _LognormalMomentMatching(Operator):
logmean, logsig = _lognormal_moments(mean, sig, N_copies) logmean, logsig = _lognormal_moments(mean, sig, N_copies)
self._mean = mean self._mean = mean
self._sig = sig self._sig = sig
op = _normal(logmean, logsig, key, N_copies).exp() op = _normal(logmean, logsig, key, N_copies).ptw("exp")
self._domain, self._target = op.domain, op.target self._domain, self._target = op.domain, op.target
self.apply = op.apply self.apply = op.apply
...@@ -224,8 +224,8 @@ class _Normalization(Operator): ...@@ -224,8 +224,8 @@ class _Normalization(Operator):
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
amp = x.exp() amp = x.ptw("exp")
spec = (2*x).exp() spec = amp**2
# FIXME This normalizes also the zeromode which is supposed to be left # FIXME This normalizes also the zeromode which is supposed to be left
# untouched by this operator # untouched by this operator
return self._specsum(self._mode_multiplicity(spec))**(-0.5)*amp return self._specsum(self._mode_multiplicity(spec))**(-0.5)*amp
...@@ -332,17 +332,17 @@ class _Amplitude(Operator): ...@@ -332,17 +332,17 @@ class _Amplitude(Operator):
sig_fluc = vol1 @ ps_expander @ fluctuations sig_fluc = vol1 @ ps_expander @ fluctuations
xi = ducktape(dom, None, key) xi = ducktape(dom, None, key)
sigma = sig_flex*(Adder(shift) @ sig_asp).sqrt() sigma = sig_flex*(Adder(shift) @ sig_asp).ptw("sqrt")
smooth = _SlopeRemover(target, space) @ twolog @ (sigma*xi) smooth = _SlopeRemover(target, space) @ twolog @ (sigma*xi)
op = _Normalization(target, space) @ (slope + smooth) op = _Normalization(target, space) @ (slope + smooth)
if N_copies > 0: if N_copies > 0:
op = Distributor @ op op = Distributor @ op
sig_fluc = Distributor @ sig_fluc sig_fluc = Distributor @ sig_fluc
op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.one_over())*op) op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.ptw("reciprocal"))*op)
self._fluc = (_Distributor(dofdex, fluctuations.target, self._fluc = (_Distributor(dofdex, fluctuations.target,
distributed_tgt[0]) @ fluctuations) distributed_tgt[0]) @ fluctuations)
else: else:
op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.one_over())*op) op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.ptw("reciprocal"))*op)
self._fluc = fluctuations self._fluc = fluctuations
self.apply = op.apply self.apply = op.apply
...@@ -527,7 +527,7 @@ class CorrelatedFieldMaker: ...@@ -527,7 +527,7 @@ class CorrelatedFieldMaker:
for _ in range(prior_info): for _ in range(prior_info):
sc.add(op(from_random('normal', op.domain))) sc.add(op(from_random('normal', op.domain)))
mean = sc.mean.val mean = sc.mean.val
stddev = sc.var.sqrt().val stddev = sc.var.ptw("sqrt").val
for m, s in zip(mean.flatten(), stddev.flatten()): for m, s in zip(mean.flatten(), stddev.flatten()):
logger.info('{}: {:.02E} ± {:.02E}'.format(kk, m, s)) logger.info('{}: {:.02E} ± {:.02E}'.format(kk, m, s))
...@@ -539,7 +539,7 @@ class CorrelatedFieldMaker: ...@@ -539,7 +539,7 @@ class CorrelatedFieldMaker:
from ..sugar import from_random from ..sugar import from_random
scm = 1. scm = 1.
for a in self._a: for a in self._a:
op = a.fluctuation_amplitude*self._azm.one_over() op = a.fluctuation_amplitude*self._azm.ptw("reciprocal")
res = np.array([op(from_random('normal', op.domain)).val res = np.array([op(from_random('normal', op.domain)).val
for _ in range(nsamples)]) for _ in range(nsamples)])
scm *= res**2 + 1. scm *= res**2 + 1.
...@@ -573,9 +573,9 @@ class CorrelatedFieldMaker: ...@@ -573,9 +573,9 @@ class CorrelatedFieldMaker:
return self.average_fluctuation(0) return self.average_fluctuation(0)
q = 1. q = 1.
for a in self._a: for a in self._a:
fl = a.fluctuation_amplitude*self._azm.one_over() fl = a.fluctuation_amplitude*self._azm.ptw("reciprocal")
q = q*(Adder(full(fl.target, 1.)) @ fl**2) q = q*(Adder(full(fl.target, 1.)) @ fl**2)
return (Adder(full(q.target, -1.)) @ q).sqrt()*self._azm return (Adder(full(q.target, -1.)) @ q).ptw("sqrt")*self._azm
def slice_fluctuation(self, space): def slice_fluctuation(self, space):
"""Returns operator which acts on prior or posterior samples""" """Returns operator which acts on prior or posterior samples"""
...@@ -587,12 +587,12 @@ class CorrelatedFieldMaker: ...@@ -587,12 +587,12 @@ class CorrelatedFieldMaker:
return self.average_fluctuation(0) return self.average_fluctuation(0)
q = 1. q = 1.
for j in range(len(self._a)): for j in range(len(self._a)):
fl = self._a[j].fluctuation_amplitude*self._azm.one_over() fl = self._a[j].fluctuation_amplitude*self._azm.ptw("reciprocal")
if j == space: if j == space:
q = q*fl**2 q = q*fl**2
else: else:
q = q*(Adder(full(fl.target, 1.)) @ fl**2) q = q*(Adder(full(fl.target, 1.)) @ fl**2)
return q.sqrt()*self._azm return q.ptw("sqrt")*self._azm
def average_fluctuation(self, space): def average_fluctuation(self, space):
"""Returns operator which acts on prior or posterior samples""" """Returns operator which acts on prior or posterior samples"""
......
...@@ -97,9 +97,9 @@ def _make_dynamic_operator(target, harmonic_padding, sm_s0, sm_x0, cone, keys, c ...@@ -97,9 +97,9 @@ def _make_dynamic_operator(target, harmonic_padding, sm_s0, sm_x0, cone, keys, c
m = CentralPadd.adjoint(FFTB(Sm(m))) m = CentralPadd.adjoint(FFTB(Sm(m)))
ops['smoothed_dynamics'] = m ops['smoothed_dynamics'] = m
m = -m.log() m = -m.ptw("log")
if not minimum_phase: if not minimum_phase:
m = m.exp() m = m.ptw("exp")
if causal or minimum_phase: if causal or minimum_phase:
m = Real.adjoint(FFT.inverse(Realizer(FFT.target).adjoint(m))) m = Real.adjoint(FFT.inverse(Realizer(FFT.target).adjoint(m)))
kernel = makeOp( kernel = makeOp(
...@@ -114,19 +114,19 @@ def _make_dynamic_operator(target, harmonic_padding, sm_s0, sm_x0, cone, keys, c ...@@ -114,19 +114,19 @@ def _make_dynamic_operator(target, harmonic_padding, sm_s0, sm_x0, cone, keys, c
c = FieldAdapter(UnstructuredDomain(len(sigc)), keys[1]) c = FieldAdapter(UnstructuredDomain(len(sigc)), keys[1])
c = makeOp(Field(c.target, np.array(sigc)))(c) c = makeOp(Field(c.target, np.array(sigc)))(c)
lightspeed = ScalingOperator(c.target, -0.5)(c).exp() lightspeed = ScalingOperator(c.target, -0.5)(c).ptw("exp")
scaling = np.array(m.target[0].distances[1:])/m.target[0].distances[0] scaling = np.array(m.target[0].distances[1:])/m.target[0].distances[0]
scaling = DiagonalOperator(Field(c.target, scaling)) scaling = DiagonalOperator(Field(c.target, scaling))
ops['lightspeed'] = scaling(lightspeed) ops['lightspeed'] = scaling(lightspeed)
c = LightConeOperator(c.target, m.target, quant) @ c.exp() c = LightConeOperator(c.target, m.target, quant) @ c.ptw("exp")
ops['light_cone'] = c ops['light_cone'] = c
m = c*m m = c*m
if causal or minimum_phase: if causal or minimum_phase:
m = FFT(Real(m)) m = FFT(Real(m))
if minimum_phase: if minimum_phase:
m = m.exp() m = m.ptw("exp")
return m, ops return m, ops
......
...@@ -19,7 +19,6 @@ import numpy as np ...@@ -19,7 +19,6 @@ import numpy as np
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..field import Field from ..field import Field
from ..linearization import Linearization
from ..operators.linear_operator import LinearOperator from ..operators.linear_operator import LinearOperator
from ..operators.operator import Operator from ..operators.operator import Operator
...@@ -132,7 +131,7 @@ class LightConeOperator(Operator): ...@@ -132,7 +131,7 @@ class LightConeOperator(Operator):
self._sigx = sigx self._sigx = sigx
def apply(self, x): def apply(self, x):
lin = isinstance(x, Linearization) lin = x.jac is not None
a, derivs = _cone_arrays(x.val.val if lin else x.val, self.target, self._sigx, lin) a, derivs = _cone_arrays(x.val.val if lin else x.val, self.target, self._sigx, lin)
res = Field(self.target, a) res = Field(self.target, a)
if not lin: if not lin:
......
...@@ -22,7 +22,6 @@ from scipy.interpolate import CubicSpline ...@@ -22,7 +22,6 @@ from scipy.interpolate import CubicSpline
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field from ..field import Field
from ..linearization import Linearization
from ..operators.operator import Operator from ..operators.operator import Operator
from ..sugar import makeOp from ..sugar import makeOp
from .. import random from .. import random
...@@ -79,7 +78,7 @@ class _InterpolationOperator(Operator): ...@@ -79,7 +78,7 @@ class _InterpolationOperator(Operator):
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
lin = isinstance(x, Linearization) lin = x.jac is not None
xval = x.val.val if lin else x.val xval = x.val.val if lin else x.val
res = self._interpolator(xval) res = self._interpolator(xval)
res = Field(self._domain, res) res = Field(self._domain, res)
...@@ -120,7 +119,7 @@ def InverseGammaOperator(domain, alpha, q, delta=1e-2): ...@@ -120,7 +119,7 @@ def InverseGammaOperator(domain, alpha, q, delta=1e-2):
Distance between sampling points for linear interpolation. Distance between sampling points for linear interpolation.
""" """
op = _InterpolationOperator(domain, lambda x: invgamma.ppf(norm._cdf(x), float(alpha)), op = _InterpolationOperator(domain, lambda x: invgamma.ppf(norm._cdf(x), float(alpha)),
-8.2, 8.2, delta, lambda x: x.log(), lambda x: x.exp()) -8.2, 8.2, delta, lambda x: x.ptw("log"), lambda x: x.ptw("exp"))
if np.isscalar(q): if np.isscalar(q):
return op.scale(q) return op.scale(q)
return makeOp(q) @ op return makeOp(q) @ op
...@@ -148,7 +147,7 @@ class UniformOperator(Operator): ...@@ -148,7 +147,7 @@ class UniformOperator(Operator):
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
lin = isinstance(x, Linearization) lin = x.jac is not None
xval = x.val.val if lin else x.val xval = x.val.val if lin else x.val
res = Field(self._target, self._scale*norm._cdf(xval) + self._loc) res = Field(self._target, self._scale*norm._cdf(xval) + self._loc)
if not lin: if not lin:
......
...@@ -17,13 +17,12 @@ ...@@ -17,13 +17,12 @@
import numpy as np import numpy as np
from .field import Field
from .multi_field import MultiField
from .sugar import makeOp from .sugar import makeOp
from . import utilities from . import utilities
from .operators.operator import Operator
class Linearization(object): class Linearization(Operator):
"""Let `A` be an operator and `x` a field. `Linearization` stores the value """Let `A` be an operator and `x` a field. `Linearization` stores the value
of the operator application (i.e. `A(x)`), the local Jacobian of the operator application (i.e. `A(x)`), the local Jacobian
(i.e. `dA(x)/dx`) and, optionally, the local metric. (i.e. `dA(x)/dx`) and, optionally, the local metric.
...@@ -64,7 +63,7 @@ class Linearization(object): ...@@ -64,7 +63,7 @@ class Linearization(object):
return Linearization(val, jac, metric, self._want_metric) return Linearization(val, jac, metric, self._want_metric)
def trivial_jac(self): def trivial_jac(self):
return Linearization.make_var(self._val, self._want_metric) return self.make_var(self._val, self._want_metric)
def prepend_jac(self, jac): def prepend_jac(self, jac):
metric = None metric = None
...@@ -101,6 +100,7 @@ class Linearization(object): ...@@ -101,6 +100,7 @@ class Linearization(object):
----- -----
Only available if target is a scalar Only available if target is a scalar
""" """
from .field import Field
return self._jac.adjoint_times(Field.scalar(1.)) return self._jac.adjoint_times(Field.scalar(1.))
@property @property
...@@ -135,18 +135,15 @@ class Linearization(object): ...@@ -135,18 +135,15 @@ class Linearization(object):
return self.new(self._val.real, self._jac.real) return self.new(self._val.real, self._jac.real)
def _myadd(self, other, neg): def _myadd(self, other, neg):
if isinstance(other, Linearization): if np.isscalar(other) or other.jac is None:
return self.new(self._val-other if neg else self._val+other,
self._jac, self._metric)
met = None met = None
if self._metric is not None and other._metric is not None: if self._metric is not None and other._metric is not None:
met = self._metric._myadd(other._metric, neg) met = self._metric._myadd(other._metric, neg)
return self.new( return self.new(
self._val.flexible_addsub(other._val, neg), self.val.flexible_addsub(other.val, neg),
self._jac._myadd(other._jac, neg), met) self.jac._myadd(other.jac, neg), met)
if isinstance(other, (int, float, complex, Field, MultiField)):
if neg:
return self.new(self._val-other, self._jac, self._metric)
else:
return self.new(self._val+other, self._jac, self._metric)
def __add__(self, other): def __add__(self, other):
return self._myadd(other, False) return self._myadd(other, False)
...@@ -161,37 +158,35 @@ class Linearization(object): ...@@ -161,37 +158,35 @@ class Linearization(object):
return (-self).__add__(other) return (-self).__add__(other)
def __truediv__(self, other): def __truediv__(self, other):
if isinstance(other, Linearization): if np.isscalar(other):
return self.__mul__(other.one_over()) return self.__mul__(1/other)
return self.__mul__(1./other) return self.__mul__(other.ptw("reciprocal"))
def __rtruediv__(self, other): def __rtruediv__(self, other):
return self.one_over().__mul__(other) return self.ptw("reciprocal").__mul__(other)
def __pow__(self, power): def __pow__(self, power):
if not np.isscalar(power): if not (np.isscalar(power) or power.jac is None):
return NotImplemented return NotImplemented
return self.new(self._val**power, return self.ptw("power", power)
makeOp(self._val**(power-1)).scale(power)(self._jac))
def __mul__(self, other): def __mul__(self, other):
from .sugar import makeOp
if isinstance(other, Linearization):
if self.target != other.target:
raise ValueError("domain mismatch")
return self.new(
self._val*other._val,
(makeOp(other._val)(self._jac))._myadd(
makeOp(self._val)(other._jac), False))
if np.isscalar(other): if np.isscalar(other):
if other == 1: if other == 1:
return self return self
met = None if self._metric is None else self._metric.scale(other) met = None if self._metric is None else self._metric.scale(other)
return self.new(self._val*other, self._jac.scale(other), met) return self.new(self._val*other, self._jac.scale(other), met)
if isinstance(other, (Field, MultiField)): from .sugar import makeOp
if other.jac is None:
if self.target != other.domain: if self.target != other.domain:
raise ValueError("domain mismatch") raise ValueError("domain mismatch")
return self.new(self._val*other, makeOp(other)(self._jac)) return self.new(self._val*other, makeOp(other)(self._jac))
if self.target != other.target:
raise ValueError("domain mismatch")
return self.new(
self.val*other.val,
(makeOp(other.val)(self.jac))._myadd(
makeOp(self.val)(other.jac), False))
def __rmul__(self, other): def __rmul__(self, other):
return self.__mul__(other) return self.__mul__(other)
...@@ -209,17 +204,16 @@ class Linearization(object): ...@@ -209,17 +204,16 @@ class Linearization(object):
Linearization Linearization
the outer product of self and other the outer product of self and other
""" """
if np.isscalar(other):
return self.__mul__(other)
from .operators.outer_product_operator import OuterProduct from .operators.outer_product_operator import OuterProduct
if isinstance(other, Linearization): if other.jac is None:
return self.new(OuterProduct(self._val, other.domain)(other),
OuterProduct(self._jac(self._val), other.domain))
return self.new( return self.new(
OuterProduct(self._val, other.target)(other._val), OuterProduct(self._val, other.target)(other._val),
OuterProduct(self._jac(self._val), other.target)._myadd( OuterProduct(self._jac(self._val), other.target)._myadd(
OuterProduct(self._val, other.target)(other._jac), False)) OuterProduct(self._val, other.target)(other._jac), False))
if np.isscalar(other):
return self.__mul__(other)
if isinstance(other, (Field, MultiField)):
return self.new(OuterProduct(self._val, other.domain)(other),
OuterProduct(self._jac(self._val), other.domain))
def vdot(self, other): def vdot(self, other):
"""Computes the inner product of this Linearization with a Field or """Computes the inner product of this Linearization with a Field or
...@@ -235,7 +229,7 @@ class Linearization(object): ...@@ -235,7 +229,7 @@ class Linearization(object):
the inner product of self and other the inner product of self and other
""" """
from .operators.simple_linear_operators import VdotOperator from .operators.simple_linear_operators import VdotOperator
if isinstance(other, (Field, MultiField)): if other.jac is None:
return self.new( return self.new(
self._val.vdot(other), self._val.vdot(other),
VdotOperator(other)(self._jac)) VdotOperator(other)(self._jac))
...@@ -282,105 +276,10 @@ class Linearization(object): ...@@ -282,105 +276,10 @@ class Linearization(object):
self._val.integrate(spaces), self._val.integrate(spaces),
ContractionOperator(self._jac.target, spaces, 1)(self._jac)) ContractionOperator(self._jac.target, spaces, 1)(self._jac))
def exp(self): def ptw(self, op, *args, **kwargs):
tmp = self._val.exp() from .pointwise import ptw_dict