Commit a40cbebf authored by Martin Reinecke's avatar Martin Reinecke

cleanup

parent 6386588d
......@@ -23,18 +23,14 @@ from .field import Field
__all__ = ['cos', 'sin', 'cosh', 'sinh', 'tan', 'tanh', 'arccos', 'arcsin',
'arccosh', 'arcsinh', 'arctan', 'arctanh', 'sqrt', 'exp', 'log',
'conjugate', 'clipped_exp', 'limited_exp', 'limited_exp_deriv']
'conjugate']
def _math_helper(x, function):
if isinstance(x, Field):
result_val = function(x.val)
result = x.copy_empty(dtype=result_val.dtype)
result.val = result_val
return Field(val=function(x.val))
else:
result = function(np.asarray(x))
return result
return function(np.asarray(x))
def cos(x):
......@@ -93,36 +89,6 @@ def exp(x):
return _math_helper(x, np.exp)
def clipped_exp(x):
return _math_helper(x, lambda z: np.exp(np.minimum(200, z)))
def limited_exp(x):
return _math_helper(x, _limited_exp_helper)
def _limited_exp_helper(x):
thr = 200.
mask = x>thr
if np.count_nonzero(mask) == 0:
return np.exp(x)
result = ((1.-thr) + x)*np.exp(thr)
result[~mask] = np.exp(x[~mask])
return result
def limited_exp_deriv(x):
return _math_helper(x, _limited_exp_deriv_helper)
def _limited_exp_deriv_helper(x):
thr = 200.
mask = x>thr
if np.count_nonzero(mask) == 0:
return np.exp(x)
result = np.empty_like(x)
result[mask] = np.exp(thr)
result[~mask] = np.exp(x[~mask])
return result
def log(x, base=None):
result = _math_helper(x, np.log)
if base is not None:
......
......@@ -89,8 +89,21 @@ class Field(object):
else:
global_shape = reduce(lambda x, y: x + y, shape_tuple)
dtype = self._infer_dtype(dtype=dtype, val=val)
self._val = np.empty(global_shape,dtype=dtype)
self.set_val(new_val=val, copy=copy)
if isinstance(val, Field):
if self.domain!=val.domain:
raise ValueError("Domain mismatch")
self._val = np.array(val.val,dtype=dtype,copy=copy)
elif (np.isscalar(val)):
self._val=np.full(global_shape,dtype=dtype,fill_value=val)
elif isinstance(val, np.ndarray):
if global_shape==val.shape:
self._val = np.array(val,dtype=dtype,copy=copy)
else:
raise ValueError("Shape mismatch")
elif val is None:
self._val = np.empty(global_shape,dtype=dtype)
else:
raise TypeError("unknown source type")
def _parse_domain(self, domain, val=None):
if domain is None:
......@@ -412,7 +425,6 @@ class Field(object):
# apply the rescaler to the random fields
result_list[0].val *= spec.real
if not real_power:
result_list[1].val *= spec.imag
......@@ -481,7 +493,7 @@ class Field(object):
if copy:
self._val[()] = new_val.val
else:
self._val = new_val.val
self._val = np.array(new_val.val,dtype=self.dtype,copy=False)
elif (np.isscalar(new_val)):
self._val[()]=new_val
elif isinstance(new_val, np.ndarray):
......@@ -490,7 +502,7 @@ class Field(object):
else:
if self.shape!=new_val.shape:
raise ValueError("Shape mismatch")
self._val = new_val
self._val = np.array(new_val,dtype=self.dtype,copy=False)
else:
raise TypeError("unknown source type")
return self
......@@ -573,12 +585,7 @@ class Field(object):
shape
"""
dim_tuple = tuple(sp.dim for sp in self.domain)
try:
return int(reduce(lambda x, y: x * y, dim_tuple))
except TypeError:
return 0
return self._val.size
@property
def dof(self):
......
from ...operators import EndomorphicOperator,\
InvertibleOperatorMixin
from ...energies.memoization import memo
from ...basic_arithmetics import clipped_exp
from ...basic_arithmetics import exp
from ...sugar import create_composed_fft_operator
......@@ -71,7 +71,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
@property
@memo
def _expp_sspace(self):
return clipped_exp(self._fft(self.position))
return exp(self._fft(self.position))
@property
@memo
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment