Commit 2e434668 authored by Martin Reinecke's avatar Martin Reinecke

simplification and cosmetics

parent 5d2241a3
......@@ -78,8 +78,6 @@ from .library.amplitude_model import AmplitudeModel
from .library.inverse_gamma_model import InverseGammaModel
from .library.los_response import LOSResponse
#from .library.inverse_gamma_model import InverseGammaModel
from .library.wiener_filter_curvature import WienerFilterCurvature
from .library.correlated_fields import CorrelatedField
# make_mf_correlated_field)
......
......@@ -47,13 +47,11 @@ class Field(object):
"""
def __init__(self, domain, val):
self._uni = None
if not isinstance(domain, DomainTuple):
raise TypeError("domain must be of type DomainTuple")
if not isinstance(val, dobj.data_object):
if type(val) is not dobj.data_object:
if np.isscalar(val):
self._uni = val
val = dobj.uniform_full(domain.shape, val)
val = dobj.full(domain.shape, val)
else:
raise TypeError("val must be of type dobj.data_object")
if domain.shape != val.shape:
......@@ -394,14 +392,10 @@ class Field(object):
return self
def __neg__(self):
if self._uni is None:
return Field(self._domain, -self._val)
return Field(self._domain, -self._uni)
return Field(self._domain, -self._val)
def __abs__(self):
if self._uni is None:
return Field(self._domain, abs(self._val))
return Field(self._domain, abs(self._uni))
return Field(self._domain, abs(self._val))
def _contraction_helper(self, op, spaces):
if spaces is None:
......@@ -617,96 +611,12 @@ class Field(object):
return self + other
def positive_tanh(self):
if self._uni is None:
return 0.5*(1.+self.tanh())
return Field(self._domain, 0.5*(1.+np.tanh(self._uni)))
def __add__(self, other):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain is not self._domain:
raise ValueError("domains are incompatible.")
if self._uni is None:
if other._uni is None:
return Field(self._domain, self._val+other._val)
if other._uni == 0:
return self
return Field(self._domain, self._val+other._uni)
else:
if self._uni == 0:
return other
if other._uni is None:
return Field(self._domain, other._val+self._uni)
return Field(self._domain, self._uni+other._uni)
if np.isscalar(other):
if self._uni is None:
return Field(self._domain, self._val+other)
return Field(self._domain, self._uni+other)
return NotImplemented
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain is not self._domain:
raise ValueError("domains are incompatible.")
if self._uni is None:
if other._uni is None:
return Field(self._domain, self._val-other._val)
if other._uni == 0:
return self
return Field(self._domain, self._val-other._uni)
else:
if self._uni == 0:
return -other
if other._uni is None:
return Field(self._domain, self._uni-other._val)
return Field(self._domain, self._uni-other._uni)
if np.isscalar(other):
if self._uni is None:
return Field(self._domain, self._val-other)
return Field(self._domain, self._uni-other)
return NotImplemented
def __mul__(self, other):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain is not self._domain:
raise ValueError("domains are incompatible.")
if self._uni is None:
if other._uni is None:
return Field(self._domain, self._val*other._val)
if other._uni == 1:
return self
if other._uni == 0:
return other
return Field(self._domain, self._val*other._uni)
else:
if self._uni == 1:
return other
if self._uni == 0:
return self
if other._uni is None:
return Field(self._domain, other._val*self._uni)
return Field(self._domain, self._uni*other._uni)
if np.isscalar(other):
if self._uni is None:
if other == 1:
return self
if other == 0:
return Field(self._domain, other)
return Field(self._domain, self._val*other)
return Field(self._domain, self._uni*other)
return NotImplemented
for op in ["__rsub__",
"__rmul__",
return 0.5*(1.+self.tanh())
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
"__mul__", "__rmul__",
"__div__", "__rdiv__",
"__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__",
......@@ -739,11 +649,7 @@ for op in ["__iadd__", "__isub__", "__imul__", "__idiv__",
for f in ["sqrt", "exp", "log", "tanh"]:
def func(f):
def func2(self):
if self._uni is None:
fu = getattr(dobj, f)
return Field(domain=self._domain, val=fu(self.val))
else:
fu = getattr(np, f)
return Field(domain=self._domain, val=fu(self._uni))
fu = getattr(dobj, f)
return Field(domain=self._domain, val=fu(self.val))
return func2
setattr(Field, f, func(f))
......@@ -102,10 +102,10 @@ class Linearization(object):
from .operators.simple_linear_operators import VdotOperator
if isinstance(other, (Field, MultiField)):
return Linearization(
Field(DomainTuple.scalar_domain(),self._val.vdot(other)),
Field(DomainTuple.scalar_domain(), self._val.vdot(other)),
VdotOperator(other)(self._jac))
return Linearization(
Field(DomainTuple.scalar_domain(),self._val.vdot(other._val)),
Field(DomainTuple.scalar_domain(), self._val.vdot(other._val)),
VdotOperator(self._val)(other._jac) +
VdotOperator(other._val)(self._jac))
......
......@@ -26,12 +26,12 @@ from .iteration_controller import IterationController
from .minimizer import Minimizer
def _toNdarray(fld):
def _toArray(fld):
return fld.to_global_data().reshape(-1)
def _toFlatNdarray(fld):
return fld.val.flatten()
def _toArray_rw(fld):
return fld.to_global_data_rw().reshape(-1)
def _toField(arr, dom):
......@@ -54,12 +54,12 @@ class _MinHelper(object):
def jac(self, x):
self._update(x)
return _toFlatNdarray(self._energy.gradient)
return _toArray_rw(self._energy.gradient)
def hessp(self, x, p):
self._update(x)
res = self._energy.metric(_toField(p, self._domain))
return _toFlatNdarray(res)
return _toArray_rw(res)
class ScipyMinimizer(Minimizer):
......@@ -95,7 +95,7 @@ class ScipyMinimizer(Minimizer):
else:
raise ValueError("unrecognized bounds")
x = hlp._energy.position.val.flatten()
x = _toArray_rw(hlp._energy.position)
hessp = hlp.hessp if self._need_hessp else None
r = opt.minimize(hlp.fun, x, method=self._method, jac=hlp.jac,
hessp=hessp, options=self._options, bounds=bounds)
......@@ -147,11 +147,11 @@ class ScipyCG(Minimizer):
self._op = op
def __call__(self, inp):
return _toNdarray(self._op(_toField(inp, self._op.domain)))
return _toArray(self._op(_toField(inp, self._op.domain)))
op = energy._A
b = _toNdarray(energy._b)
sx = _toNdarray(energy.position)
b = _toArray(energy._b)
sx = _toArray(energy.position)
sci_op = scipy_linop(shape=(op.domain.size, op.target.size),
matvec=mymatvec(op))
prec_op = None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment