diff --git a/nifty6/domains/rg_space.py b/nifty6/domains/rg_space.py index c27864b25564b3b309f0d2e97e5507e37bc0bbb9..4adc200c63b7137a36d3c452dbf162486abd6d18 100644 --- a/nifty6/domains/rg_space.py +++ b/nifty6/domains/rg_space.py @@ -142,8 +142,7 @@ class RGSpace(StructuredDomain): @staticmethod def _kernel(x, sigma): - from ..sugar import exp - return exp(x*x * (-2.*np.pi*np.pi*sigma*sigma)) + return (x*x * (-2.*np.pi*np.pi*sigma*sigma)).ptw("exp") def get_fft_smoothing_kernel_function(self, sigma): if (not self.harmonic): diff --git a/nifty6/field.py b/nifty6/field.py index 306dad50301323e290411e81c0e88bb518b3ad73..ad57202b6159a823a53a433e1e9fec233081ef98 100644 --- a/nifty6/field.py +++ b/nifty6/field.py @@ -20,9 +20,10 @@ import numpy as np from . import utilities from .domain_tuple import DomainTuple +from .operators.operator import Operator -class Field(object): +class Field(Operator): """The discrete representation of a continuous field over multiple spaces. Stores data arrays and carries all the needed meta-information (i.e. the @@ -634,10 +635,9 @@ class Field(object): Field The result of the operation. """ - from .sugar import sqrt if self.scalar_weight(spaces) is not None: return self._contraction_helper('std', spaces) - return sqrt(self.var(spaces)) + return self.var(spaces).ptw("sqrt") def s_std(self): """Determines the standard deviation of the Field. @@ -677,17 +677,6 @@ class Field(object): def flexible_addsub(self, other, neg): return self-other if neg else self+other - def sigmoid(self): - return 0.5*(1.+self.tanh()) - - def clip(self, min=None, max=None): - min = min.val if isinstance(min, Field) else min - max = max.val if isinstance(max, Field) else max - return Field(self._domain, np.clip(self._val, min, max)) - - def one_over(self): - return 1/self - def _binary_op(self, other, op): # if other is a field, make sure that the domains match f = getattr(self._val, op) @@ -699,6 +688,26 @@ class Field(object): return Field(self._domain, f(other)) return NotImplemented + def _prep_args(self, args, kwargs): + for arg in args + tuple(kwargs.values()): + if not (arg is None or np.isscalar(arg) or arg.jac is None): + raise TypeError("bad argument") + argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val + for arg in args) + kwargstmp = {key: val if val is None or np.isscalar(val) else val._val + for key, val in kwargs.items()} + return argstmp, kwargstmp + + def ptw(self, op, *args, **kwargs): + from .pointwise import ptw_dict + argstmp, kwargstmp = self._prep_args(args, kwargs) + return Field(self._domain, ptw_dict[op][0](self._val, *argstmp, **kwargstmp)) + + def ptw_with_deriv(self, op, *args, **kwargs): + from .pointwise import ptw_dict + argstmp, kwargstmp = self._prep_args(args, kwargs) + tmp = ptw_dict[op][1](self._val, *argstmp, **kwargstmp) + return (Field(self._domain, tmp[0]), Field(self._domain, tmp[1])) for op in ["__add__", "__radd__", "__sub__", "__rsub__", @@ -721,11 +730,3 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", "In-place operations are deliberately not supported") return func2 setattr(Field, op, func(op)) - -for f in ["sqrt", "exp", "log", "sin", "cos", "tan", "sinh", "cosh", "tanh", - "absolute", "sinc", "sign", "log10", "log1p", "expm1"]: - def func(f): - def func2(self): - return Field(self._domain, getattr(np, f)(self.val)) - return func2 - setattr(Field, f, func(f)) diff --git a/nifty6/library/correlated_fields.py b/nifty6/library/correlated_fields.py index 9da15b1e49e837c7521b2a9205295991dcc845bb..c5156c3e1836bccc3b59146327e5eaea694dc1f8 100644 --- a/nifty6/library/correlated_fields.py +++ b/nifty6/library/correlated_fields.py @@ -126,7 +126,7 @@ class _LognormalMomentMatching(Operator): logmean, logsig = _lognormal_moments(mean, sig, N_copies) self._mean = mean self._sig = sig - op = _normal(logmean, logsig, key, N_copies).exp() + op = _normal(logmean, logsig, key, N_copies).ptw("exp") self._domain, self._target = op.domain, op.target self.apply = op.apply @@ -224,8 +224,8 @@ class _Normalization(Operator): def apply(self, x): self._check_input(x) - amp = x.exp() - spec = (2*x).exp() + amp = x.ptw("exp") + spec = amp**2 # FIXME This normalizes also the zeromode which is supposed to be left # untouched by this operator return self._specsum(self._mode_multiplicity(spec))**(-0.5)*amp @@ -332,17 +332,17 @@ class _Amplitude(Operator): sig_fluc = vol1 @ ps_expander @ fluctuations xi = ducktape(dom, None, key) - sigma = sig_flex*(Adder(shift) @ sig_asp).sqrt() + sigma = sig_flex*(Adder(shift) @ sig_asp).ptw("sqrt") smooth = _SlopeRemover(target, space) @ twolog @ (sigma*xi) op = _Normalization(target, space) @ (slope + smooth) if N_copies > 0: op = Distributor @ op sig_fluc = Distributor @ sig_fluc - op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.one_over())*op) + op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.ptw("reciprocal"))*op) self._fluc = (_Distributor(dofdex, fluctuations.target, distributed_tgt[0]) @ fluctuations) else: - op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.one_over())*op) + op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.ptw("reciprocal"))*op) self._fluc = fluctuations self.apply = op.apply @@ -527,7 +527,7 @@ class CorrelatedFieldMaker: for _ in range(prior_info): sc.add(op(from_random('normal', op.domain))) mean = sc.mean.val - stddev = sc.var.sqrt().val + stddev = sc.var.ptw("sqrt").val for m, s in zip(mean.flatten(), stddev.flatten()): logger.info('{}: {:.02E} ± {:.02E}'.format(kk, m, s)) @@ -539,7 +539,7 @@ class CorrelatedFieldMaker: from ..sugar import from_random scm = 1. for a in self._a: - op = a.fluctuation_amplitude*self._azm.one_over() + op = a.fluctuation_amplitude*self._azm.ptw("reciprocal") res = np.array([op(from_random('normal', op.domain)).val for _ in range(nsamples)]) scm *= res**2 + 1. @@ -573,9 +573,9 @@ class CorrelatedFieldMaker: return self.average_fluctuation(0) q = 1. for a in self._a: - fl = a.fluctuation_amplitude*self._azm.one_over() + fl = a.fluctuation_amplitude*self._azm.ptw("reciprocal") q = q*(Adder(full(fl.target, 1.)) @ fl**2) - return (Adder(full(q.target, -1.)) @ q).sqrt()*self._azm + return (Adder(full(q.target, -1.)) @ q).ptw("sqrt")*self._azm def slice_fluctuation(self, space): """Returns operator which acts on prior or posterior samples""" @@ -587,12 +587,12 @@ class CorrelatedFieldMaker: return self.average_fluctuation(0) q = 1. for j in range(len(self._a)): - fl = self._a[j].fluctuation_amplitude*self._azm.one_over() + fl = self._a[j].fluctuation_amplitude*self._azm.ptw("reciprocal") if j == space: q = q*fl**2 else: q = q*(Adder(full(fl.target, 1.)) @ fl**2) - return q.sqrt()*self._azm + return q.ptw("sqrt")*self._azm def average_fluctuation(self, space): """Returns operator which acts on prior or posterior samples""" diff --git a/nifty6/library/dynamic_operator.py b/nifty6/library/dynamic_operator.py index cef9462ab3bb8445064a8e044641f837fb54620d..947989f0633de547759f819b6a1b84c5e854fa5e 100644 --- a/nifty6/library/dynamic_operator.py +++ b/nifty6/library/dynamic_operator.py @@ -97,9 +97,9 @@ def _make_dynamic_operator(target, harmonic_padding, sm_s0, sm_x0, cone, keys, c m = CentralPadd.adjoint(FFTB(Sm(m))) ops['smoothed_dynamics'] = m - m = -m.log() + m = -m.ptw("log") if not minimum_phase: - m = m.exp() + m = m.ptw("exp") if causal or minimum_phase: m = Real.adjoint(FFT.inverse(Realizer(FFT.target).adjoint(m))) kernel = makeOp( @@ -114,19 +114,19 @@ def _make_dynamic_operator(target, harmonic_padding, sm_s0, sm_x0, cone, keys, c c = FieldAdapter(UnstructuredDomain(len(sigc)), keys[1]) c = makeOp(Field(c.target, np.array(sigc)))(c) - lightspeed = ScalingOperator(c.target, -0.5)(c).exp() + lightspeed = ScalingOperator(c.target, -0.5)(c).ptw("exp") scaling = np.array(m.target[0].distances[1:])/m.target[0].distances[0] scaling = DiagonalOperator(Field(c.target, scaling)) ops['lightspeed'] = scaling(lightspeed) - c = LightConeOperator(c.target, m.target, quant) @ c.exp() + c = LightConeOperator(c.target, m.target, quant) @ c.ptw("exp") ops['light_cone'] = c m = c*m if causal or minimum_phase: m = FFT(Real(m)) if minimum_phase: - m = m.exp() + m = m.ptw("exp") return m, ops diff --git a/nifty6/library/light_cone_operator.py b/nifty6/library/light_cone_operator.py index 5f336a8ef7f65625bd0ac5e6ab7bd0897c4c0ce9..14157079031ccd7ff78d5b4adbbf3c5b042ccabb 100644 --- a/nifty6/library/light_cone_operator.py +++ b/nifty6/library/light_cone_operator.py @@ -19,7 +19,6 @@ import numpy as np from ..domain_tuple import DomainTuple from ..field import Field -from ..linearization import Linearization from ..operators.linear_operator import LinearOperator from ..operators.operator import Operator @@ -132,7 +131,7 @@ class LightConeOperator(Operator): self._sigx = sigx def apply(self, x): - lin = isinstance(x, Linearization) + lin = x.jac is not None a, derivs = _cone_arrays(x.val.val if lin else x.val, self.target, self._sigx, lin) res = Field(self.target, a) if not lin: diff --git a/nifty6/library/special_distributions.py b/nifty6/library/special_distributions.py index 5d0d4ad7a64d98695488ac163a6bc428b55df85a..c06b15beca3076792e6b31d4d31dee368a8959f2 100644 --- a/nifty6/library/special_distributions.py +++ b/nifty6/library/special_distributions.py @@ -22,7 +22,6 @@ from scipy.interpolate import CubicSpline from ..domain_tuple import DomainTuple from ..domains.unstructured_domain import UnstructuredDomain from ..field import Field -from ..linearization import Linearization from ..operators.operator import Operator from ..sugar import makeOp from .. import random @@ -79,7 +78,7 @@ class _InterpolationOperator(Operator): def apply(self, x): self._check_input(x) - lin = isinstance(x, Linearization) + lin = x.jac is not None xval = x.val.val if lin else x.val res = self._interpolator(xval) res = Field(self._domain, res) @@ -120,7 +119,7 @@ def InverseGammaOperator(domain, alpha, q, delta=1e-2): Distance between sampling points for linear interpolation. """ op = _InterpolationOperator(domain, lambda x: invgamma.ppf(norm._cdf(x), float(alpha)), - -8.2, 8.2, delta, lambda x: x.log(), lambda x: x.exp()) + -8.2, 8.2, delta, lambda x: x.ptw("log"), lambda x: x.ptw("exp")) if np.isscalar(q): return op.scale(q) return makeOp(q) @ op @@ -148,7 +147,7 @@ class UniformOperator(Operator): def apply(self, x): self._check_input(x) - lin = isinstance(x, Linearization) + lin = x.jac is not None xval = x.val.val if lin else x.val res = Field(self._target, self._scale*norm._cdf(xval) + self._loc) if not lin: diff --git a/nifty6/linearization.py b/nifty6/linearization.py index fdd23a0b46b03449e47fcd4458b2a8d3235f7882..83156aa2323218b67304b3a06b61594edb2af668 100644 --- a/nifty6/linearization.py +++ b/nifty6/linearization.py @@ -17,13 +17,12 @@ import numpy as np -from .field import Field -from .multi_field import MultiField from .sugar import makeOp from . import utilities +from .operators.operator import Operator -class Linearization(object): +class Linearization(Operator): """Let `A` be an operator and `x` a field. `Linearization` stores the value of the operator application (i.e. `A(x)`), the local Jacobian (i.e. `dA(x)/dx`) and, optionally, the local metric. @@ -64,7 +63,7 @@ class Linearization(object): return Linearization(val, jac, metric, self._want_metric) def trivial_jac(self): - return Linearization.make_var(self._val, self._want_metric) + return self.make_var(self._val, self._want_metric) def prepend_jac(self, jac): metric = None @@ -101,6 +100,7 @@ class Linearization(object): ----- Only available if target is a scalar """ + from .field import Field return self._jac.adjoint_times(Field.scalar(1.)) @property @@ -135,18 +135,15 @@ class Linearization(object): return self.new(self._val.real, self._jac.real) def _myadd(self, other, neg): - if isinstance(other, Linearization): - met = None - if self._metric is not None and other._metric is not None: - met = self._metric._myadd(other._metric, neg) - return self.new( - self._val.flexible_addsub(other._val, neg), - self._jac._myadd(other._jac, neg), met) - if isinstance(other, (int, float, complex, Field, MultiField)): - if neg: - return self.new(self._val-other, self._jac, self._metric) - else: - return self.new(self._val+other, self._jac, self._metric) + if np.isscalar(other) or other.jac is None: + return self.new(self._val-other if neg else self._val+other, + self._jac, self._metric) + met = None + if self._metric is not None and other._metric is not None: + met = self._metric._myadd(other._metric, neg) + return self.new( + self.val.flexible_addsub(other.val, neg), + self.jac._myadd(other.jac, neg), met) def __add__(self, other): return self._myadd(other, False) @@ -161,37 +158,35 @@ class Linearization(object): return (-self).__add__(other) def __truediv__(self, other): - if isinstance(other, Linearization): - return self.__mul__(other.one_over()) - return self.__mul__(1./other) + if np.isscalar(other): + return self.__mul__(1/other) + return self.__mul__(other.ptw("reciprocal")) def __rtruediv__(self, other): - return self.one_over().__mul__(other) + return self.ptw("reciprocal").__mul__(other) def __pow__(self, power): - if not np.isscalar(power): + if not (np.isscalar(power) or power.jac is None): return NotImplemented - return self.new(self._val**power, - makeOp(self._val**(power-1)).scale(power)(self._jac)) + return self.ptw("power", power) def __mul__(self, other): - from .sugar import makeOp - if isinstance(other, Linearization): - if self.target != other.target: - raise ValueError("domain mismatch") - return self.new( - self._val*other._val, - (makeOp(other._val)(self._jac))._myadd( - makeOp(self._val)(other._jac), False)) if np.isscalar(other): if other == 1: return self met = None if self._metric is None else self._metric.scale(other) return self.new(self._val*other, self._jac.scale(other), met) - if isinstance(other, (Field, MultiField)): + from .sugar import makeOp + if other.jac is None: if self.target != other.domain: raise ValueError("domain mismatch") return self.new(self._val*other, makeOp(other)(self._jac)) + if self.target != other.target: + raise ValueError("domain mismatch") + return self.new( + self.val*other.val, + (makeOp(other.val)(self.jac))._myadd( + makeOp(self.val)(other.jac), False)) def __rmul__(self, other): return self.__mul__(other) @@ -209,17 +204,16 @@ class Linearization(object): Linearization the outer product of self and other """ - from .operators.outer_product_operator import OuterProduct - if isinstance(other, Linearization): - return self.new( - OuterProduct(self._val, other.target)(other._val), - OuterProduct(self._jac(self._val), other.target)._myadd( - OuterProduct(self._val, other.target)(other._jac), False)) if np.isscalar(other): return self.__mul__(other) - if isinstance(other, (Field, MultiField)): + from .operators.outer_product_operator import OuterProduct + if other.jac is None: return self.new(OuterProduct(self._val, other.domain)(other), OuterProduct(self._jac(self._val), other.domain)) + return self.new( + OuterProduct(self._val, other.target)(other._val), + OuterProduct(self._jac(self._val), other.target)._myadd( + OuterProduct(self._val, other.target)(other._jac), False)) def vdot(self, other): """Computes the inner product of this Linearization with a Field or @@ -235,7 +229,7 @@ class Linearization(object): the inner product of self and other """ from .operators.simple_linear_operators import VdotOperator - if isinstance(other, (Field, MultiField)): + if other.jac is None: return self.new( self._val.vdot(other), VdotOperator(other)(self._jac)) @@ -282,105 +276,10 @@ class Linearization(object): self._val.integrate(spaces), ContractionOperator(self._jac.target, spaces, 1)(self._jac)) - def exp(self): - tmp = self._val.exp() - return self.new(tmp, makeOp(tmp)(self._jac)) - - def clip(self, min=None, max=None): - tmp = self._val.clip(min, max) - if (min is None) and (max is None): - return self - elif max is None: - tmp2 = makeOp(1. - (tmp == min)) - elif min is None: - tmp2 = makeOp(1. - (tmp == max)) - else: - tmp2 = makeOp(1. - (tmp == min) - (tmp == max)) - return self.new(tmp, tmp2(self._jac)) - - def sqrt(self): - tmp = self._val.sqrt() - return self.new(tmp, makeOp(0.5/tmp)(self._jac)) - - def sin(self): - tmp = self._val.sin() - tmp2 = self._val.cos() - return self.new(tmp, makeOp(tmp2)(self._jac)) - - def cos(self): - tmp = self._val.cos() - tmp2 = - self._val.sin() - return self.new(tmp, makeOp(tmp2)(self._jac)) - - def tan(self): - tmp = self._val.tan() - tmp2 = 1./(self._val.cos()**2) - return self.new(tmp, makeOp(tmp2)(self._jac)) - - def sinc(self): - tmp = self._val.sinc() - tmp2 = ((np.pi*self._val).cos()-tmp)/self._val - ind = self._val.val == 0 - loc = tmp2.val_rw() - loc[ind] = 0 - tmp2 = Field(tmp.domain, loc) - return self.new(tmp, makeOp(tmp2)(self._jac)) - - def log(self): - tmp = self._val.log() - return self.new(tmp, makeOp(1./self._val)(self._jac)) - - def log10(self): - tmp = self._val.log10() - tmp2 = 1. / (self._val * np.log(10)) - return self.new(tmp, makeOp(tmp2)(self._jac)) - - def log1p(self): - tmp = self._val.log1p() - tmp2 = 1. / (1. + self._val) - return self.new(tmp, makeOp(tmp2)(self.jac)) - - def expm1(self): - tmp = self._val.expm1() - tmp2 = self._val.exp() - return self.new(tmp, makeOp(tmp2)(self.jac)) - - def sinh(self): - tmp = self._val.sinh() - tmp2 = self._val.cosh() - return self.new(tmp, makeOp(tmp2)(self._jac)) - - def cosh(self): - tmp = self._val.cosh() - tmp2 = self._val.sinh() - return self.new(tmp, makeOp(tmp2)(self._jac)) - - def tanh(self): - tmp = self._val.tanh() - return self.new(tmp, makeOp(1.-tmp**2)(self._jac)) - - def sigmoid(self): - tmp = self._val.tanh() - tmp2 = 0.5*(1.+tmp) - return self.new(tmp2, makeOp(0.5*(1.-tmp**2))(self._jac)) - - def absolute(self): - if utilities.iscomplextype(self._val.dtype): - raise TypeError("Argument must not be complex") - tmp = self._val.absolute() - tmp2 = self._val.sign() - - ind = self._val.val == 0 - loc = tmp2.val_rw().astype(float) - loc[ind] = np.nan - tmp2 = Field(tmp.domain, loc) - - return self.new(tmp, makeOp(tmp2)(self._jac)) - - def one_over(self): - tmp = 1./self._val - tmp2 = - tmp/self._val - return self.new(tmp, makeOp(tmp2)(self._jac)) + def ptw(self, op, *args, **kwargs): + from .pointwise import ptw_dict + t1, t2 = self._val.ptw_with_deriv(op, *args, **kwargs) + return self.new(t1, makeOp(t2)(self._jac)) def add_metric(self, metric): return self.new(self._val, self._jac, metric) diff --git a/nifty6/multi_field.py b/nifty6/multi_field.py index 08b0b23c4ef5e106d67164c9809e5e827f8accda..341234904debbe83082ee2c2da965a47507e68a5 100644 --- a/nifty6/multi_field.py +++ b/nifty6/multi_field.py @@ -21,9 +21,10 @@ from . import utilities from .field import Field from .multi_domain import MultiDomain from .domain_tuple import DomainTuple +from .operators.operator import Operator -class MultiField(object): +class MultiField(Operator): def __init__(self, domain, val): """The discrete representation of a continuous field over a sum space. @@ -199,13 +200,8 @@ class MultiField(object): def conjugate(self): return self._transform(lambda x: x.conjugate()) - def clip(self, min=None, max=None): - ncomp = len(self._val) - lmin = min._val if isinstance(min, MultiField) else (min,)*ncomp - lmax = max._val if isinstance(max, MultiField) else (max,)*ncomp - return MultiField( - self._domain, - tuple(self._val[i].clip(lmin[i], lmax[i]) for i in range(ncomp))) + def clip(self, a_min=None, a_max=None): + return self.ptw("clip", a_min, a_max) def s_all(self): for v in self._val: @@ -310,8 +306,30 @@ class MultiField(object): res[key] = -val if neg else val return MultiField.from_dict(res) - def one_over(self): - return 1/self + def _prep_args(self, args, kwargs, i): + for arg in args + tuple(kwargs.values()): + if not (arg is None or np.isscalar(arg) or arg.jac is None): + raise TypeError("bad argument") + argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val[i] + for arg in args) + kwargstmp = {key: val if val is None or np.isscalar(val) else val._val[i] + for key, val in kwargs.items()} + return argstmp, kwargstmp + + def ptw(self, op, *args, **kwargs): + tmp = [] + for i in range(len(self._val)): + argstmp, kwargstmp = self._prep_args(args, kwargs, i) + tmp.append(self._val[i].ptw(op, *argstmp, **kwargstmp)) + return MultiField(self.domain, tuple(tmp)) + + def ptw_with_deriv(self, op, *args, **kwargs): + tmp = [] + for i in range(len(self._val)): + argstmp, kwargstmp = self._prep_args(args, kwargs, i) + tmp.append(self._val[i].ptw_with_deriv(op, *argstmp, **kwargstmp)) + return (MultiField(self.domain, tuple(v[0] for v in tmp)), + MultiField(self.domain, tuple(v[1] for v in tmp))) def _binary_op(self, other, op): f = getattr(Field, op) @@ -347,14 +365,3 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", "In-place operations are deliberately not supported") return func2 setattr(MultiField, op, func(op)) - - -for f in ["sqrt", "exp", "log", "sin", "cos", "tan", "sinh", "cosh", "tanh", - "absolute", "sinc", "sign", "log10", "log1p", "expm1"]: - def func(f): - def func2(self): - fu = getattr(Field, f) - return MultiField(self.domain, - tuple(fu(val) for val in self.values())) - return func2 - setattr(MultiField, f, func(f)) diff --git a/nifty6/operators/energy_operators.py b/nifty6/operators/energy_operators.py index 28d08b4b22c0ea315e3216658d264254f43e0dfc..d201a3d3252eb0ae9e19dc6706f72f54be231168 100644 --- a/nifty6/operators/energy_operators.py +++ b/nifty6/operators/energy_operators.py @@ -20,7 +20,6 @@ import numpy as np from .. import utilities from ..domain_tuple import DomainTuple from ..field import Field -from ..linearization import Linearization from ..multi_domain import MultiDomain from ..multi_field import MultiField from ..sugar import makeDomain, makeOp @@ -59,9 +58,8 @@ class Squared2NormOperator(EnergyOperator): def apply(self, x): self._check_input(x) - if not isinstance(x, Linearization): - res = x.vdot(x) - return res + if x.jac is None: + return x.vdot(x) res = x.val.vdot(x.val) return x.new(res, VdotOperator(2*x.val)) @@ -88,7 +86,7 @@ class QuadraticFormOperator(EnergyOperator): def apply(self, x): self._check_input(x) - if not isinstance(x, Linearization): + if x.jac is None: return 0.5*x.vdot(self._op(x)) res = 0.5*x.val.vdot(self._op(x.val)) return x.new(res, VdotOperator(self._op(x.val))) @@ -127,8 +125,8 @@ class VariableCovarianceGaussianEnergy(EnergyOperator): def apply(self, x): self._check_input(x) - res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov]).real - x[self._icov].log().sum()) - if not isinstance(x, Linearization) or not x.want_metric: + res = 0.5*(x[self._r].vdot(x[self._r]*x[self._icov]).real - x[self._icov].ptw("log").sum()) + if not x.want_metric: return res mf = {self._r: x.val[self._icov], self._icov: .5*x.val[self._icov]**(-2)} return res.add_metric(makeOp(MultiField.from_dict(mf))) @@ -195,7 +193,7 @@ class GaussianEnergy(EnergyOperator): self._check_input(x) residual = x if self._mean is None else x - self._mean res = self._op(residual).real - if isinstance(x, Linearization) and x.want_metric: + if x.want_metric: return res.add_metric(self._met) return res @@ -229,8 +227,8 @@ class PoissonianEnergy(EnergyOperator): def apply(self, x): self._check_input(x) - res = x.sum() - x.log().vdot(self._d) - if not isinstance(x, Linearization) or not x.want_metric: + res = x.sum() - x.ptw("log").vdot(self._d) + if not x.want_metric: return res return res.add_metric(makeOp(1./x.val)) @@ -269,8 +267,8 @@ class InverseGammaLikelihood(EnergyOperator): def apply(self, x): self._check_input(x) - res = x.log().vdot(self._alphap1) + x.one_over().vdot(self._beta) - if not isinstance(x, Linearization) or not x.want_metric: + res = x.ptw("log").vdot(self._alphap1) + x.ptw("reciprocal").vdot(self._beta) + if not x.want_metric: return res return res.add_metric(makeOp(self._alphap1/(x.val**2))) @@ -298,8 +296,8 @@ class StudentTEnergy(EnergyOperator): def apply(self, x): self._check_input(x) - res = ((self._theta+1)/2)*(x**2/self._theta).log1p().sum() - if not isinstance(x, Linearization) or not x.want_metric: + res = ((self._theta+1)/2)*(x**2/self._theta).ptw("log1p").sum() + if not x.want_metric: return res met = ScalingOperator(self.domain, (self._theta+1) / (self._theta+3)) return res.add_metric(met) @@ -332,8 +330,8 @@ class BernoulliEnergy(EnergyOperator): def apply(self, x): self._check_input(x) - res = -x.log().vdot(self._d) + (1.-x).log().vdot(self._d-1.) - if not isinstance(x, Linearization) or not x.want_metric: + res = -x.ptw("log").vdot(self._d) + (1.-x).ptw("log").vdot(self._d-1.) + if not x.want_metric: return res return res.add_metric(makeOp(1./(x.val*(1. - x.val)))) @@ -382,7 +380,7 @@ class StandardHamiltonian(EnergyOperator): def apply(self, x): self._check_input(x) - if not isinstance(x, Linearization) or not x.want_metric or self._ic_samp is None: + if not x.want_metric or self._ic_samp is None: return (self._lh + self._prior)(x) lhx, prx = self._lh(x), self._prior(x) return (lhx+prx).add_metric(SamplingEnabler(lhx.metric, prx.metric, self._ic_samp)) diff --git a/nifty6/operators/linear_operator.py b/nifty6/operators/linear_operator.py index 366085a7fa2d55b29f70186a9e5e833192084c5b..f9f24b50487bbfefb88c6883b7077267c60abeee 100644 --- a/nifty6/operators/linear_operator.py +++ b/nifty6/operators/linear_operator.py @@ -171,10 +171,10 @@ class LinearOperator(Operator): def __call__(self, x): """Same as :meth:`times`""" from ..linearization import Linearization - if isinstance(x, (Field, MultiField)): - return self.apply(x, self.TIMES) - if isinstance(x, Linearization): + if x.jac is not None: return x.new(self(x._val), self).prepend_jac(x.jac) + if x.val is not None: + return self.apply(x, self.TIMES) return self@x def times(self, x): diff --git a/nifty6/operators/operator.py b/nifty6/operators/operator.py index 6ae43d1800da92ffc3951a3d66bec215b4eeafe7..ab1a42bfa868312b27d3cf9c6f33fce6e788a2f9 100644 --- a/nifty6/operators/operator.py +++ b/nifty6/operators/operator.py @@ -17,9 +17,8 @@ import numpy as np -from ..field import Field -from ..multi_field import MultiField from ..utilities import NiftyMeta, indent +from .. import pointwise class Operator(metaclass=NiftyMeta): @@ -45,9 +44,65 @@ class Operator(metaclass=NiftyMeta): ------- target : DomainTuple or MultiDomain """ - return self._target + @property + def val(self): + """The numerical value associated with this object + For "pure" operators this is `None`. For Field-like objects this + is a `numpy.ndarray` or a dictionary of `numpy.ndarray`s mathcing the + object's `target`. + + Returns + ------- + None or numpy.ndarray or dictionary of np.ndarrays : the numerical value + """ + return None + + @property + def jac(self): + """The Jacobian associated with this object + For "pure" operators this is `None`. For Field-like objects this + can be `None` (in which case the object is a constant), or it can be a + `LinearOperator` with `domain` and `target` matching the object's. + + Returns + ------- + None or LinearOperator : the Jacobian + + Notes + ----- + if `value` is None, this must be `None` as well! + """ + return None + + @property + def want_metric(self): + """Whether a metric should be computed for the full expression. + This is `False` whenever `jac` is `None`. In other cases it signals + that operators processing this object should compute the metric. + + Returns + ------- + bool : whether the metric should be computed + """ + return False + + @property + def metric(self): + """The metric associated with the object. + This is `None`, except when all the following conditions hold: + - `want_metric` is `True` + - `target` is the scalar domain + - the operator chain contained an operator which could compute the + metric + + Returns + ------- + None or LinearOperator : the metric + """ + return None + @staticmethod def _check_domain_equality(dom_op, dom_field): if dom_op != dom_field: @@ -74,15 +129,13 @@ class Operator(metaclass=NiftyMeta): return ContractionOperator(self.target, spaces)(self) def vdot(self, other): - from ..field import Field - from ..multi_field import MultiField from ..sugar import makeOp - if isinstance(other, Operator): + if not isinstance(other, Operator): + raise TypeError + if other.jac is None: res = self.conjugate()*other - elif isinstance(other, (Field, MultiField)): - res = makeOp(other) @ self.conjugate() else: - raise TypeError + res = makeOp(other) @ self.conjugate() return res.sum() @property @@ -153,14 +206,9 @@ class Operator(metaclass=NiftyMeta): return _OpSum(self, -x) def __pow__(self, power): - if not np.isscalar(power): + if not (np.isscalar(power) or power.jac is None): return NotImplemented - return _OpChain.make((_PowerOp(self.target, power), self)) - - def clip(self, min=None, max=None): - if min is None and max is None: - return self - return _OpChain.make((_Clipper(self.target, min, max), self)) + return self.ptw("power", power) def apply(self, x): """Applies the operator to a Field or MultiField. @@ -179,11 +227,10 @@ class Operator(metaclass=NiftyMeta): return self.apply(x.extract(self.domain)) def _check_input(self, x): - from ..linearization import Linearization from .scaling_operator import ScalingOperator - if not isinstance(x, (Field, MultiField, Linearization)): + if not (isinstance(x, Operator) and x.val is not None): raise TypeError - if isinstance(x, Linearization): + if x.jac is not None: if not isinstance(x.jac, ScalingOperator): raise ValueError if x.jac._factor != 1: @@ -191,12 +238,11 @@ class Operator(metaclass=NiftyMeta): self._check_domain_equality(self._domain, x.domain) def __call__(self, x): - from ..linearization import Linearization - from ..field import Field - from ..multi_field import MultiField - if isinstance(x, Linearization): + if not isinstance(x, Operator): + raise TypeError + if x.jac is not None: return self.apply(x.trivial_jac()).prepend_jac(x.jac) - elif isinstance(x, (Field, MultiField)): + elif x.val is not None: return self.apply(x) return self @ x @@ -222,13 +268,14 @@ class Operator(metaclass=NiftyMeta): def _simplify_for_constant_input_nontrivial(self, c_inp): return None, self + def ptw(self, op, *args, **kwargs): + return _OpChain.make((_FunctionApplier(self.target, op, *args, **kwargs), self)) -for f in ["sqrt", "exp", "log", "sin", "cos", "tan", "sinh", "cosh", "tanh", - "sinc", "sigmoid", "absolute", "one_over", "log10", "log1p", "expm1"]: + +for f in pointwise.ptw_dict.keys(): def func(f): - def func2(self): - fa = _FunctionApplier(self.target, f) - return _OpChain.make((fa, self)) + def func2(self, *args, **kwargs): + return self.ptw(f, *args, **kwargs) return func2 setattr(Operator, f, func(f)) @@ -282,10 +329,9 @@ class _ConstantOperator(Operator): self._output = output def apply(self, x): - from ..linearization import Linearization from .simple_linear_operators import NullOperator self._check_input(x) - if isinstance(x, Linearization): + if x.jac is not None: return x.new(self._output, NullOperator(self._domain, self._target)) return self._output @@ -294,37 +340,16 @@ class _ConstantOperator(Operator): class _FunctionApplier(Operator): - def __init__(self, domain, funcname): + def __init__(self, domain, funcname, *args, **kwargs): from ..sugar import makeDomain self._domain = self._target = makeDomain(domain) self._funcname = funcname + self._args = args + self._kwargs = kwargs def apply(self, x): self._check_input(x) - return getattr(x, self._funcname)() - - -class _Clipper(Operator): - def __init__(self, domain, min=None, max=None): - from ..sugar import makeDomain - self._domain = self._target = makeDomain(domain) - self._min = min - self._max = max - - def apply(self, x): - self._check_input(x) - return x.clip(self._min, self._max) - - -class _PowerOp(Operator): - def __init__(self, domain, power): - from ..sugar import makeDomain - self._domain = self._target = makeDomain(domain) - self._power = power - - def apply(self, x): - self._check_input(x) - return x**self._power + return x.ptw(self._funcname, *self._args, **self._kwargs) class _CombinedOperator(Operator): @@ -395,8 +420,8 @@ class _OpProd(Operator): from ..linearization import Linearization from ..sugar import makeOp self._check_input(x) - lin = isinstance(x, Linearization) - wm = x.want_metric if lin else None + lin = x.jac is not None + wm = x.want_metric if lin else False x = x.val if lin else x v1 = x.extract(self._op1.domain) v2 = x.extract(self._op2.domain) @@ -438,7 +463,7 @@ class _OpSum(Operator): def apply(self, x): from ..linearization import Linearization self._check_input(x) - if not isinstance(x, Linearization): + if x.jac is None: v1 = x.extract(self._op1.domain) v2 = x.extract(self._op2.domain) return self._op1(v1).unite(self._op2(v2)) diff --git a/nifty6/operators/scaling_operator.py b/nifty6/operators/scaling_operator.py index 491dbed0a0e01dcd9a71d49a7ac3d15b060dbafa..b99d3380fa751f3f7bbfb8b15931efebd46ebad9 100644 --- a/nifty6/operators/scaling_operator.py +++ b/nifty6/operators/scaling_operator.py @@ -99,8 +99,7 @@ class ScalingOperator(EndomorphicOperator): def __call__(self, other): res = EndomorphicOperator.__call__(self, other) if np.isreal(self._factor) and self._factor >= 0: - from ..linearization import Linearization - if isinstance(other, Linearization): + if other.jac is not None: if other.metric is not None: from .sandwich_operator import SandwichOperator sqrt_fac = np.sqrt(self._factor) diff --git a/nifty6/pointwise.py b/nifty6/pointwise.py new file mode 100644 index 0000000000000000000000000000000000000000..198518a61bd510f7eb851ff1677c091f3f8b68f9 --- /dev/null +++ b/nifty6/pointwise.py @@ -0,0 +1,103 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2020 Max-Planck-Society +# Author: Martin Reinecke +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. + +import numpy as np + + +def _sqrt_helper(v): + tmp = np.sqrt(v) + return (tmp, 0.5/tmp) + + +def _sinc_helper(v): + tmp = np.sinc(v) + tmp2 = (np.cos(np.pi*v)-tmp)/v + return (tmp, np.where(v==0., 0, tmp2)) + + +def _expm1_helper(v): + tmp = np.expm1(v) + return (tmp, tmp+1.) + + +def _tanh_helper(v): + tmp = np.tanh(v) + return (tmp, 1.-tmp**2) + + +def _sigmoid_helper(v): + tmp = np.tanh(v) + tmp2 = 0.5+(0.5*tmp) + return (tmp2, 0.5-0.5*tmp**2) + + +def _reciprocal_helper(v): + tmp = 1./v + return (tmp, -tmp**2) + + +def _abs_helper(v): + if np.issubdtype(v.dtype, np.complexfloating): + raise TypeError("Argument must not be complex") + return (np.abs(v), np.where(v==0, np.nan, np.sign(v))) + + +def _sign_helper(v): + if np.issubdtype(v.dtype, np.complexfloating): + raise TypeError("Argument must not be complex") + return (np.sign(v), np.where(v==0, np.nan, 0)) + + +def _power_helper(v, expo): + return (np.power(v, expo), expo*np.power(v, expo-1)) + + +def _clip_helper(v, a_min, a_max): + if np.issubdtype(v.dtype, np.complexfloating): + raise TypeError("Argument must not be complex") + tmp = np.clip(v, a_min, a_max) + tmp2 = np.ones(v.shape) + if a_min is not None: + tmp2 = np.where(tmp==a_min, 0., tmp2) + if a_max is not None: + tmp2 = np.where(tmp==a_max, 0., tmp2) + return (tmp, tmp2) + + +ptw_dict = { + "sqrt": (np.sqrt, _sqrt_helper), + "sin" : (np.sin, lambda v: (np.sin(v), np.cos(v))), + "cos" : (np.cos, lambda v: (np.cos(v), -np.sin(v))), + "tan" : (np.tan, lambda v: (np.tan(v), 1./np.cos(v)**2)), + "sinc": (np.sinc, _sinc_helper), + "exp" : (np.exp, lambda v: (2*(np.exp(v),))), + "expm1" : (np.expm1, _expm1_helper), + "log" : (np.log, lambda v: (np.log(v), 1./v)), + "log10": (np.log10, lambda v: (np.log10(v), (1./np.log(10.))/v)), + "log1p": (np.log1p, lambda v: (np.log1p(v), 1./(1.+v))), + "sinh": (np.sinh, lambda v: (np.sinh(v), np.cosh(v))), + "cosh": (np.cosh, lambda v: (np.cosh(v), np.sinh(v))), + "tanh": (np.tanh, _tanh_helper), + "sigmoid": (lambda v: 0.5+(0.5*np.tanh(v)), _sigmoid_helper), + "reciprocal": (lambda v: 1./v, _reciprocal_helper), + "abs": (np.abs, _abs_helper), + "absolute": (np.abs, _abs_helper), + "sign": (np.sign, _sign_helper), + "power": (np.power, _power_helper), + "clip": (np.clip, _clip_helper), + } diff --git a/nifty6/sugar.py b/nifty6/sugar.py index 8937136727dde8b6636af6cff764769a18a4b6b1..c1afe9de2737d68709eb1004080675db646d5076 100644 --- a/nifty6/sugar.py +++ b/nifty6/sugar.py @@ -33,16 +33,15 @@ from .operators.distributors import PowerDistributor from .operators.operator import Operator from .operators.scaling_operator import ScalingOperator from .plot import Plot +from . import pointwise + __all__ = ['PS_field', 'power_analyze', 'create_power_operator', 'create_harmonic_smoothing_operator', 'from_random', 'full', 'makeField', - 'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'sigmoid', - 'sin', 'cos', 'tan', 'sinh', 'cosh', 'log10', - 'absolute', 'one_over', 'clip', 'sinc', "log1p", "expm1", - 'conjugate', 'get_signal_variance', 'makeOp', 'domain_union', + 'makeDomain', 'get_signal_variance', 'makeOp', 'domain_union', 'get_default_codomain', 'single_plot', 'exec_time', - 'calculate_position'] + 'calculate_position'] + list(pointwise.ptw_dict.keys()) def PS_field(pspace, func): @@ -320,7 +319,7 @@ def makeDomain(domain): return DomainTuple.make(domain) -def makeOp(input): +def makeOp(input, dom=None): """Converts a Field or MultiField to a diagonal operator. Parameters @@ -334,12 +333,22 @@ def makeOp(input): - if MultiField, a BlockDiagonalOperator with entries given by this MultiField is returned. + dom : DomainTuple or MultiDomain + if `input` is a scalar, this is used as the operator's domain + Notes ----- No volume factors are applied. """ if input is None: return None + if np.isscalar(input): + if not isinstance(dom, (DomainTuple, MultiDomain)): + raise TypeError("need proper `dom` argument") + return SalingOperator(dom, input) + if dom is not None: + if not dom == input.domain: + raise ValueError("domain mismatch") if input.domain is DomainTuple.scalar_domain(): return ScalingOperator(input.domain, float(input.val)) if isinstance(input, Field): @@ -366,30 +375,18 @@ def domain_union(domains): return MultiDomain.union(domains) -# Arithmetic functions working on Fields - +# Pointwise functions _current_module = sys.modules[__name__] -for f in ["sqrt", "exp", "log", "log10", "tanh", "sigmoid", - "conjugate", 'sin', 'cos', 'tan', 'sinh', 'cosh', - 'absolute', 'one_over', 'sinc', 'log1p', 'expm1']: +for f in pointwise.ptw_dict.keys(): def func(f): - def func2(x): - from .linearization import Linearization - from .operators.operator import Operator - if isinstance(x, (Field, MultiField, Linearization, Operator)): - return getattr(x, f)() - else: - return getattr(np, f)(x) + def func2(x, *args, **kwargs): + return x.ptw(f, *args, **kwargs) return func2 setattr(_current_module, f, func(f)) -def clip(a, a_min=None, a_max=None): - return a.clip(a_min, a_max) - - def get_default_codomain(domainoid, space=None): """For `RGSpace`, returns the harmonic partner domain. For `DomainTuple`, returns a copy of the object in which the domain diff --git a/test/test_energy_gradients.py b/test/test_energy_gradients.py index 259fbb0f3d2cfd30f4203a1827b276027e9e632b..13eb90e8cf27877b77fb4124a8f64c2f96380b45 100644 --- a/test/test_energy_gradients.py +++ b/test/test_energy_gradients.py @@ -77,7 +77,7 @@ def test_studentt(field): def test_hamiltonian_and_KL(field): - field = field.exp() + field = field.ptw("exp") space = field.domain lh = ift.GaussianEnergy(domain=space) hamiltonian = ift.StandardHamiltonian(lh) @@ -91,7 +91,7 @@ def test_hamiltonian_and_KL(field): def test_variablecovariancegaussian(field): if isinstance(field.domain, ift.MultiDomain): return - dc = {'a': field, 'b': field.exp()} + dc = {'a': field, 'b': field.ptw("exp")} mf = ift.MultiField.from_dict(dc) energy = ift.VariableCovarianceGaussianEnergy(field.domain, 'a', 'b') ift.extra.check_jacobian_consistency(energy, mf, tol=1e-6) @@ -101,7 +101,7 @@ def test_variablecovariancegaussian(field): def test_inverse_gamma(field): if isinstance(field.domain, ift.MultiDomain): return - field = field.exp() + field = field.ptw("exp") space = field.domain d = ift.random.current_rng().normal(10, size=space.shape)**2 d = ift.Field(space, d) @@ -112,7 +112,7 @@ def test_inverse_gamma(field): def testPoissonian(field): if isinstance(field.domain, ift.MultiDomain): return - field = field.exp() + field = field.ptw("exp") space = field.domain d = ift.random.current_rng().poisson(120, size=space.shape) d = ift.Field(space, d) @@ -123,7 +123,7 @@ def testPoissonian(field): def test_bernoulli(field): if isinstance(field.domain, ift.MultiDomain): return - field = field.sigmoid() + field = field.ptw("sigmoid") space = field.domain d = ift.random.current_rng().binomial(1, 0.1, size=space.shape) d = ift.Field(space, d) diff --git a/test/test_field.py b/test/test_field.py index f8572a4ea15ebe8007f3236a857036fe881e264d..6e154cb1994652ae2924ad590502ab53e11f4a23 100644 --- a/test/test_field.py +++ b/test/test_field.py @@ -193,12 +193,12 @@ def test_empty_domain(): def test_trivialities(): s1 = ift.RGSpace((10,)) f1 = ift.Field.full(s1, 27) - assert_equal(f1.clip(min=29).val, 29.) - assert_equal(f1.clip(max=25).val, 25.) + assert_equal(f1.clip(a_min=29, a_max=50).val, 29.) + assert_equal(f1.clip(a_min=0, a_max=25).val, 25.) assert_equal(f1.val, f1.real.val) assert_equal(f1.val, (+f1).val) f1 = ift.Field.full(s1, 27. + 3j) - assert_equal(f1.one_over().val, (1./f1).val) + assert_equal(f1.ptw("reciprocal").val, (1./f1).val) assert_equal(f1.real.val, 27.) assert_equal(f1.imag.val, 3.) assert_equal(f1.s_sum(), f1.sum(0).val) @@ -336,7 +336,7 @@ def test_emptydomain(): def test_funcs(num, dom, func): num = 5 f = ift.Field.full(dom, num) - res = getattr(f, func)() + res = f.ptw(func) res2 = getattr(np, func)(num) assert_allclose(res.val, res2) diff --git a/test/test_gaussian_energy.py b/test/test_gaussian_energy.py index 94dffb05f75ba9d98fa96e46d7be75cfa11e79d3..f19f093eeb0329ff5b593dfa4f1dc3529680bdab 100644 --- a/test/test_gaussian_energy.py +++ b/test/test_gaussian_energy.py @@ -51,7 +51,7 @@ def test_gaussian_energy(space, nonlinearity, noise, seed): return 1/(1 + k**2)**dim pspec = ift.PS_field(pspace, pspec) - A = Dist(ift.sqrt(pspec)) + A = Dist(pspec.ptw("sqrt")) N = ift.ScalingOperator(space, noise) n = N.draw_sample() R = ift.ScalingOperator(space, 10.) @@ -61,7 +61,7 @@ def test_gaussian_energy(space, nonlinearity, noise, seed): return R @ ht @ ift.makeOp(A) else: tmp = ht @ ift.makeOp(A) - nonlin = getattr(tmp, nonlinearity)() + nonlin = tmp.ptw(nonlinearity) return R @ nonlin d = d_model()(xi0) + n diff --git a/test/test_linearization.py b/test/test_linearization.py index aa031055bba99b054f523c0403201eacad6620d6..e81268c5eba2a3bb7f082e3edb7d6d45a4270cfe 100644 --- a/test/test_linearization.py +++ b/test/test_linearization.py @@ -43,19 +43,19 @@ def test_special_gradients(): jt(var.clip(-1, 0), np.zeros_like(s)) assert_allclose( - _lin2grad(ift.Linearization.make_var(0*f).sinc()), np.zeros(s.shape)) - assert_(np.isnan(_lin2grad(ift.Linearization.make_var(0*f).absolute()))) + _lin2grad(ift.Linearization.make_var(0*f).ptw("sinc")), np.zeros(s.shape)) + assert_(np.isnan(_lin2grad(ift.Linearization.make_var(0*f).ptw("abs")))) assert_allclose( - _lin2grad(ift.Linearization.make_var(0*f + 10).absolute()), + _lin2grad(ift.Linearization.make_var(0*f + 10).ptw("abs")), np.ones(s.shape)) assert_allclose( - _lin2grad(ift.Linearization.make_var(0*f - 10).absolute()), + _lin2grad(ift.Linearization.make_var(0*f - 10).ptw("abs")), -np.ones(s.shape)) @pmp('f', [ 'log', 'exp', 'sqrt', 'sin', 'cos', 'tan', 'sinc', 'sinh', 'cosh', 'tanh', - 'absolute', 'one_over', 'sigmoid', 'log10', 'log1p', "expm1" + 'absolute', 'reciprocal', 'sigmoid', 'log10', 'log1p', "expm1" ]) def test_actual_gradients(f): dom = ift.UnstructuredDomain((1,)) @@ -63,8 +63,8 @@ def test_actual_gradients(f): eps = 1e-8 var0 = ift.Linearization.make_var(fld) var1 = ift.Linearization.make_var(fld + eps) - f0 = getattr(var0, f)().val.val - f1 = getattr(var1, f)().val.val + f0 = var0.ptw(f).val.val + f1 = var1.ptw(f).val.val df0 = (f1 - f0)/eps - df1 = _lin2grad(getattr(var0, f)()) + df1 = _lin2grad(var0.ptw(f)) assert_allclose(df0, df1, rtol=100*eps) diff --git a/test/test_multi_field.py b/test/test_multi_field.py index 4033ca0495bc6cfe6b2aeca92ed3f5f534147162..9469596c9cc155c81060c189de4309d0121425e1 100644 --- a/test/test_multi_field.py +++ b/test/test_multi_field.py @@ -33,7 +33,7 @@ def test_vdot(): def test_func(): f1 = ift.from_random("normal", domain=dom, dtype=np.complex128) assert_allclose( - ift.log(ift.exp((f1)))["d1"].val, f1["d1"].val) + f1.ptw("exp").ptw("log")["d1"].val, f1["d1"].val) def test_multifield_field_consistency(): diff --git a/test/test_operators/test_convolution_operators.py b/test/test_operators/test_convolution_operators.py index ccd4bfb2c5dad0d29a9b476ced52c3f74614c664..52497023c3b003d50b1b464e9db5af9bb62ee415 100644 --- a/test/test_operators/test_convolution_operators.py +++ b/test/test_operators/test_convolution_operators.py @@ -46,7 +46,7 @@ def test_gaussian_smoothing(): N = 128 sigma = N / 10**4 dom = ift.RGSpace(N) - sig = ift.exp(ift.Field.from_random('normal', dom)) + sig = ift.Field.from_random('normal', dom).ptw("exp") fco_op = ift.FuncConvolutionOperator(dom, lambda x: gauss(x, sigma)) sm_op = ift.HarmonicSmoothingOperator(dom, sigma) assert_allclose(fco_op(sig).val, diff --git a/test/test_operators/test_correlated_fields.py b/test/test_operators/test_correlated_fields.py index bb493c43a89a2fc467e17fb09f50fb26accb3472..5794deea4c402f823a66d6e6b2cac520b92781c4 100644 --- a/test/test_operators/test_correlated_fields.py +++ b/test/test_operators/test_correlated_fields.py @@ -36,7 +36,7 @@ def testAmplitudesConsistency(rseed, sspace, Astds, offset_std_mean, N, zm_mean) sc = ift.StatCalculator() for s in samples: sc.add(op(s.extract(op.domain))) - return sc.mean.val, sc.var.sqrt().val + return sc.mean.val, sc.var.ptw("sqrt").val with ift.random.Context(rseed): nsam = 100 diff --git a/test/test_operators/test_interpolated.py b/test/test_operators/test_interpolated.py index 6259252801ab810eee5a1fdba9f5256a48ece228..2c217f37b2efe32497f540eb6248cc77f1ec1d01 100644 --- a/test/test_operators/test_interpolated.py +++ b/test/test_operators/test_interpolated.py @@ -36,7 +36,7 @@ def testInterpolationAccuracy(space, seed): S = ift.ScalingOperator(space, 1.) pos = S.draw_sample() alpha = 1.5 - qs = [0.73, pos.exp().val] + qs = [0.73, pos.ptw("exp").val] for q in qs: qfld = q if not np.isscalar(q): diff --git a/test/test_operators/test_jacobian.py b/test/test_operators/test_jacobian.py index 41a91d61c69f85c38302763f3c847f77a74580dc..9c574d2ded1d7049924ad9996a3f7edf2d3ad76c 100644 --- a/test/test_operators/test_jacobian.py +++ b/test/test_operators/test_jacobian.py @@ -67,7 +67,7 @@ def testBinary(type1, type2, space, seed): model = ift.ScalingOperator(space, 2.456)(select_s1*select_s2) pos = ift.from_random("normal", dom) ift.extra.check_jacobian_consistency(model, pos, ntries=20) - model = ift.sigmoid(2.456*(select_s1*select_s2)) + model = (2.456*(select_s1*select_s2)).ptw("sigmoid") pos = ift.from_random("normal", dom) ift.extra.check_jacobian_consistency(model, pos, ntries=20) pos = ift.from_random("normal", dom) diff --git a/test/test_operators/test_partial_multifield_insert.py b/test/test_operators/test_partial_multifield_insert.py index 670cc2271186bcd6d3194765ab5f5882b43bf92d..f43810435d117e04ed478e2aada58a63e5072fb3 100644 --- a/test/test_operators/test_partial_multifield_insert.py +++ b/test/test_operators/test_partial_multifield_insert.py @@ -30,10 +30,10 @@ dtype = list2fixture([np.float64, np.float32, np.complex64, np.complex128]) def test_part_mf_insert(): dom = ift.RGSpace(3) op1 = ift.ScalingOperator(dom, 1.32).ducktape('a').ducktape_left('a1') - op2 = ift.ScalingOperator(dom, 1).exp().ducktape('b').ducktape_left('b1') - op3 = ift.ScalingOperator(dom, 1).sin().ducktape('c').ducktape_left('c1') + op2 = ift.ScalingOperator(dom, 1).ptw("exp").ducktape('b').ducktape_left('b1') + op3 = ift.ScalingOperator(dom, 1).ptw("sin").ducktape('c').ducktape_left('c1') op4 = ift.ScalingOperator(dom, 1).ducktape('c0').ducktape_left('c')**2 - op5 = ift.ScalingOperator(dom, 1).tan().ducktape('d0').ducktape_left('d') + op5 = ift.ScalingOperator(dom, 1).ptw("tan").ducktape('d0').ducktape_left('d') a = op1 + op2 + op3 b = op4 + op5 op = a.partial_insert(b) diff --git a/test/test_sugar.py b/test/test_sugar.py index 0318aef67178760facea7cdb22672fbfe1003f42..90e0f98824a6ac61f9fbe1d704913724a67a9d44 100644 --- a/test/test_sugar.py +++ b/test/test_sugar.py @@ -41,7 +41,7 @@ def test_get_signal_variance(): def test_exec_time(): dom = ift.RGSpace(12, harmonic=True) op = ift.HarmonicTransformOperator(dom) - op1 = op.exp() + op1 = op.ptw("exp") lh = ift.GaussianEnergy(domain=op.target) @ op1 ic = ift.GradientNormController(iteration_limit=2) ham = ift.StandardHamiltonian(lh, ic_samp=ic) @@ -54,7 +54,7 @@ def test_exec_time(): def test_calc_pos(): dom = ift.RGSpace(12, harmonic=True) - op = ift.HarmonicTransformOperator(dom).exp() + op = ift.HarmonicTransformOperator(dom).ptw("exp") fld = op(0.1*ift.from_random('normal', op.domain)) pos = ift.calculate_position(op, fld) ift.extra.assert_allclose(op(pos), fld, 1e-1, 1e-1)