From 2e434668ccc50a1b67e9de0acadb11f85b11b64c Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Wed, 8 Aug 2018 16:59:23 +0200 Subject: [PATCH] simplification and cosmetics --- nifty5/__init__.py | 2 - nifty5/field.py | 118 +++---------------------- nifty5/linearization.py | 4 +- nifty5/minimization/scipy_minimizer.py | 18 ++-- 4 files changed, 23 insertions(+), 119 deletions(-) diff --git a/nifty5/__init__.py b/nifty5/__init__.py index e54fc6b0f..6d30fb855 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -78,8 +78,6 @@ from .library.amplitude_model import AmplitudeModel from .library.inverse_gamma_model import InverseGammaModel from .library.los_response import LOSResponse -#from .library.inverse_gamma_model import InverseGammaModel - from .library.wiener_filter_curvature import WienerFilterCurvature from .library.correlated_fields import CorrelatedField # make_mf_correlated_field) diff --git a/nifty5/field.py b/nifty5/field.py index ba6d5787a..fbedb1e90 100644 --- a/nifty5/field.py +++ b/nifty5/field.py @@ -47,13 +47,11 @@ class Field(object): """ def __init__(self, domain, val): - self._uni = None if not isinstance(domain, DomainTuple): raise TypeError("domain must be of type DomainTuple") - if not isinstance(val, dobj.data_object): + if type(val) is not dobj.data_object: if np.isscalar(val): - self._uni = val - val = dobj.uniform_full(domain.shape, val) + val = dobj.full(domain.shape, val) else: raise TypeError("val must be of type dobj.data_object") if domain.shape != val.shape: @@ -394,14 +392,10 @@ class Field(object): return self def __neg__(self): - if self._uni is None: - return Field(self._domain, -self._val) - return Field(self._domain, -self._uni) + return Field(self._domain, -self._val) def __abs__(self): - if self._uni is None: - return Field(self._domain, abs(self._val)) - return Field(self._domain, abs(self._uni)) + return Field(self._domain, abs(self._val)) def _contraction_helper(self, op, spaces): if spaces is None: @@ -617,96 +611,12 @@ class Field(object): return self + other def positive_tanh(self): - if self._uni is None: - return 0.5*(1.+self.tanh()) - return Field(self._domain, 0.5*(1.+np.tanh(self._uni))) - - def __add__(self, other): - # if other is a field, make sure that the domains match - if isinstance(other, Field): - if other._domain is not self._domain: - raise ValueError("domains are incompatible.") - if self._uni is None: - if other._uni is None: - return Field(self._domain, self._val+other._val) - if other._uni == 0: - return self - return Field(self._domain, self._val+other._uni) - else: - if self._uni == 0: - return other - if other._uni is None: - return Field(self._domain, other._val+self._uni) - return Field(self._domain, self._uni+other._uni) - - if np.isscalar(other): - if self._uni is None: - return Field(self._domain, self._val+other) - return Field(self._domain, self._uni+other) - return NotImplemented - - def __radd__(self, other): - return self.__add__(other) - - def __sub__(self, other): - # if other is a field, make sure that the domains match - if isinstance(other, Field): - if other._domain is not self._domain: - raise ValueError("domains are incompatible.") - if self._uni is None: - if other._uni is None: - return Field(self._domain, self._val-other._val) - if other._uni == 0: - return self - return Field(self._domain, self._val-other._uni) - else: - if self._uni == 0: - return -other - if other._uni is None: - return Field(self._domain, self._uni-other._val) - return Field(self._domain, self._uni-other._uni) - - if np.isscalar(other): - if self._uni is None: - return Field(self._domain, self._val-other) - return Field(self._domain, self._uni-other) - return NotImplemented - - def __mul__(self, other): - # if other is a field, make sure that the domains match - if isinstance(other, Field): - if other._domain is not self._domain: - raise ValueError("domains are incompatible.") - if self._uni is None: - if other._uni is None: - return Field(self._domain, self._val*other._val) - if other._uni == 1: - return self - if other._uni == 0: - return other - return Field(self._domain, self._val*other._uni) - else: - if self._uni == 1: - return other - if self._uni == 0: - return self - if other._uni is None: - return Field(self._domain, other._val*self._uni) - return Field(self._domain, self._uni*other._uni) - - if np.isscalar(other): - if self._uni is None: - if other == 1: - return self - if other == 0: - return Field(self._domain, other) - return Field(self._domain, self._val*other) - return Field(self._domain, self._uni*other) - return NotImplemented - - -for op in ["__rsub__", - "__rmul__", + return 0.5*(1.+self.tanh()) + + +for op in ["__add__", "__radd__", + "__sub__", "__rsub__", + "__mul__", "__rmul__", "__div__", "__rdiv__", "__truediv__", "__rtruediv__", "__floordiv__", "__rfloordiv__", @@ -739,11 +649,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__", for f in ["sqrt", "exp", "log", "tanh"]: def func(f): def func2(self): - if self._uni is None: - fu = getattr(dobj, f) - return Field(domain=self._domain, val=fu(self.val)) - else: - fu = getattr(np, f) - return Field(domain=self._domain, val=fu(self._uni)) + fu = getattr(dobj, f) + return Field(domain=self._domain, val=fu(self.val)) return func2 setattr(Field, f, func(f)) diff --git a/nifty5/linearization.py b/nifty5/linearization.py index 88d7b4f15..d3394e6d9 100644 --- a/nifty5/linearization.py +++ b/nifty5/linearization.py @@ -102,10 +102,10 @@ class Linearization(object): from .operators.simple_linear_operators import VdotOperator if isinstance(other, (Field, MultiField)): return Linearization( - Field(DomainTuple.scalar_domain(),self._val.vdot(other)), + Field(DomainTuple.scalar_domain(), self._val.vdot(other)), VdotOperator(other)(self._jac)) return Linearization( - Field(DomainTuple.scalar_domain(),self._val.vdot(other._val)), + Field(DomainTuple.scalar_domain(), self._val.vdot(other._val)), VdotOperator(self._val)(other._jac) + VdotOperator(other._val)(self._jac)) diff --git a/nifty5/minimization/scipy_minimizer.py b/nifty5/minimization/scipy_minimizer.py index 7cc3830d2..9fad98daa 100644 --- a/nifty5/minimization/scipy_minimizer.py +++ b/nifty5/minimization/scipy_minimizer.py @@ -26,12 +26,12 @@ from .iteration_controller import IterationController from .minimizer import Minimizer -def _toNdarray(fld): +def _toArray(fld): return fld.to_global_data().reshape(-1) -def _toFlatNdarray(fld): - return fld.val.flatten() +def _toArray_rw(fld): + return fld.to_global_data_rw().reshape(-1) def _toField(arr, dom): @@ -54,12 +54,12 @@ class _MinHelper(object): def jac(self, x): self._update(x) - return _toFlatNdarray(self._energy.gradient) + return _toArray_rw(self._energy.gradient) def hessp(self, x, p): self._update(x) res = self._energy.metric(_toField(p, self._domain)) - return _toFlatNdarray(res) + return _toArray_rw(res) class ScipyMinimizer(Minimizer): @@ -95,7 +95,7 @@ class ScipyMinimizer(Minimizer): else: raise ValueError("unrecognized bounds") - x = hlp._energy.position.val.flatten() + x = _toArray_rw(hlp._energy.position) hessp = hlp.hessp if self._need_hessp else None r = opt.minimize(hlp.fun, x, method=self._method, jac=hlp.jac, hessp=hessp, options=self._options, bounds=bounds) @@ -147,11 +147,11 @@ class ScipyCG(Minimizer): self._op = op def __call__(self, inp): - return _toNdarray(self._op(_toField(inp, self._op.domain))) + return _toArray(self._op(_toField(inp, self._op.domain))) op = energy._A - b = _toNdarray(energy._b) - sx = _toNdarray(energy.position) + b = _toArray(energy._b) + sx = _toArray(energy.position) sci_op = scipy_linop(shape=(op.domain.size, op.target.size), matvec=mymatvec(op)) prec_op = None -- GitLab