Skip to content
Snippets Groups Projects
Commit a40cbebf authored by Martin Reinecke's avatar Martin Reinecke
Browse files

cleanup

parent 6386588d
No related branches found
No related tags found
No related merge requests found
...@@ -23,18 +23,14 @@ from .field import Field ...@@ -23,18 +23,14 @@ from .field import Field
__all__ = ['cos', 'sin', 'cosh', 'sinh', 'tan', 'tanh', 'arccos', 'arcsin', __all__ = ['cos', 'sin', 'cosh', 'sinh', 'tan', 'tanh', 'arccos', 'arcsin',
'arccosh', 'arcsinh', 'arctan', 'arctanh', 'sqrt', 'exp', 'log', 'arccosh', 'arcsinh', 'arctan', 'arctanh', 'sqrt', 'exp', 'log',
'conjugate', 'clipped_exp', 'limited_exp', 'limited_exp_deriv'] 'conjugate']
def _math_helper(x, function): def _math_helper(x, function):
if isinstance(x, Field): if isinstance(x, Field):
result_val = function(x.val) return Field(val=function(x.val))
result = x.copy_empty(dtype=result_val.dtype)
result.val = result_val
else: else:
result = function(np.asarray(x)) return function(np.asarray(x))
return result
def cos(x): def cos(x):
...@@ -93,36 +89,6 @@ def exp(x): ...@@ -93,36 +89,6 @@ def exp(x):
return _math_helper(x, np.exp) return _math_helper(x, np.exp)
def clipped_exp(x):
return _math_helper(x, lambda z: np.exp(np.minimum(200, z)))
def limited_exp(x):
return _math_helper(x, _limited_exp_helper)
def _limited_exp_helper(x):
thr = 200.
mask = x>thr
if np.count_nonzero(mask) == 0:
return np.exp(x)
result = ((1.-thr) + x)*np.exp(thr)
result[~mask] = np.exp(x[~mask])
return result
def limited_exp_deriv(x):
return _math_helper(x, _limited_exp_deriv_helper)
def _limited_exp_deriv_helper(x):
thr = 200.
mask = x>thr
if np.count_nonzero(mask) == 0:
return np.exp(x)
result = np.empty_like(x)
result[mask] = np.exp(thr)
result[~mask] = np.exp(x[~mask])
return result
def log(x, base=None): def log(x, base=None):
result = _math_helper(x, np.log) result = _math_helper(x, np.log)
if base is not None: if base is not None:
......
...@@ -89,8 +89,21 @@ class Field(object): ...@@ -89,8 +89,21 @@ class Field(object):
else: else:
global_shape = reduce(lambda x, y: x + y, shape_tuple) global_shape = reduce(lambda x, y: x + y, shape_tuple)
dtype = self._infer_dtype(dtype=dtype, val=val) dtype = self._infer_dtype(dtype=dtype, val=val)
if isinstance(val, Field):
if self.domain!=val.domain:
raise ValueError("Domain mismatch")
self._val = np.array(val.val,dtype=dtype,copy=copy)
elif (np.isscalar(val)):
self._val=np.full(global_shape,dtype=dtype,fill_value=val)
elif isinstance(val, np.ndarray):
if global_shape==val.shape:
self._val = np.array(val,dtype=dtype,copy=copy)
else:
raise ValueError("Shape mismatch")
elif val is None:
self._val = np.empty(global_shape,dtype=dtype) self._val = np.empty(global_shape,dtype=dtype)
self.set_val(new_val=val, copy=copy) else:
raise TypeError("unknown source type")
def _parse_domain(self, domain, val=None): def _parse_domain(self, domain, val=None):
if domain is None: if domain is None:
...@@ -412,7 +425,6 @@ class Field(object): ...@@ -412,7 +425,6 @@ class Field(object):
# apply the rescaler to the random fields # apply the rescaler to the random fields
result_list[0].val *= spec.real result_list[0].val *= spec.real
if not real_power: if not real_power:
result_list[1].val *= spec.imag result_list[1].val *= spec.imag
...@@ -481,7 +493,7 @@ class Field(object): ...@@ -481,7 +493,7 @@ class Field(object):
if copy: if copy:
self._val[()] = new_val.val self._val[()] = new_val.val
else: else:
self._val = new_val.val self._val = np.array(new_val.val,dtype=self.dtype,copy=False)
elif (np.isscalar(new_val)): elif (np.isscalar(new_val)):
self._val[()]=new_val self._val[()]=new_val
elif isinstance(new_val, np.ndarray): elif isinstance(new_val, np.ndarray):
...@@ -490,7 +502,7 @@ class Field(object): ...@@ -490,7 +502,7 @@ class Field(object):
else: else:
if self.shape!=new_val.shape: if self.shape!=new_val.shape:
raise ValueError("Shape mismatch") raise ValueError("Shape mismatch")
self._val = new_val self._val = np.array(new_val,dtype=self.dtype,copy=False)
else: else:
raise TypeError("unknown source type") raise TypeError("unknown source type")
return self return self
...@@ -573,12 +585,7 @@ class Field(object): ...@@ -573,12 +585,7 @@ class Field(object):
shape shape
""" """
return self._val.size
dim_tuple = tuple(sp.dim for sp in self.domain)
try:
return int(reduce(lambda x, y: x * y, dim_tuple))
except TypeError:
return 0
@property @property
def dof(self): def dof(self):
......
from ...operators import EndomorphicOperator,\ from ...operators import EndomorphicOperator,\
InvertibleOperatorMixin InvertibleOperatorMixin
from ...energies.memoization import memo from ...energies.memoization import memo
from ...basic_arithmetics import clipped_exp from ...basic_arithmetics import exp
from ...sugar import create_composed_fft_operator from ...sugar import create_composed_fft_operator
...@@ -71,7 +71,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, ...@@ -71,7 +71,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
@property @property
@memo @memo
def _expp_sspace(self): def _expp_sspace(self):
return clipped_exp(self._fft(self.position)) return exp(self._fft(self.position))
@property @property
@memo @memo
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment