Commit a40cbebf by Martin Reinecke

### cleanup

parent 6386588d
 ... ... @@ -23,18 +23,14 @@ from .field import Field __all__ = ['cos', 'sin', 'cosh', 'sinh', 'tan', 'tanh', 'arccos', 'arcsin', 'arccosh', 'arcsinh', 'arctan', 'arctanh', 'sqrt', 'exp', 'log', 'conjugate', 'clipped_exp', 'limited_exp', 'limited_exp_deriv'] 'conjugate'] def _math_helper(x, function): if isinstance(x, Field): result_val = function(x.val) result = x.copy_empty(dtype=result_val.dtype) result.val = result_val return Field(val=function(x.val)) else: result = function(np.asarray(x)) return result return function(np.asarray(x)) def cos(x): ... ... @@ -93,36 +89,6 @@ def exp(x): return _math_helper(x, np.exp) def clipped_exp(x): return _math_helper(x, lambda z: np.exp(np.minimum(200, z))) def limited_exp(x): return _math_helper(x, _limited_exp_helper) def _limited_exp_helper(x): thr = 200. mask = x>thr if np.count_nonzero(mask) == 0: return np.exp(x) result = ((1.-thr) + x)*np.exp(thr) result[~mask] = np.exp(x[~mask]) return result def limited_exp_deriv(x): return _math_helper(x, _limited_exp_deriv_helper) def _limited_exp_deriv_helper(x): thr = 200. mask = x>thr if np.count_nonzero(mask) == 0: return np.exp(x) result = np.empty_like(x) result[mask] = np.exp(thr) result[~mask] = np.exp(x[~mask]) return result def log(x, base=None): result = _math_helper(x, np.log) if base is not None: ... ...
 ... ... @@ -89,8 +89,21 @@ class Field(object): else: global_shape = reduce(lambda x, y: x + y, shape_tuple) dtype = self._infer_dtype(dtype=dtype, val=val) self._val = np.empty(global_shape,dtype=dtype) self.set_val(new_val=val, copy=copy) if isinstance(val, Field): if self.domain!=val.domain: raise ValueError("Domain mismatch") self._val = np.array(val.val,dtype=dtype,copy=copy) elif (np.isscalar(val)): self._val=np.full(global_shape,dtype=dtype,fill_value=val) elif isinstance(val, np.ndarray): if global_shape==val.shape: self._val = np.array(val,dtype=dtype,copy=copy) else: raise ValueError("Shape mismatch") elif val is None: self._val = np.empty(global_shape,dtype=dtype) else: raise TypeError("unknown source type") def _parse_domain(self, domain, val=None): if domain is None: ... ... @@ -412,7 +425,6 @@ class Field(object): # apply the rescaler to the random fields result_list[0].val *= spec.real if not real_power: result_list[1].val *= spec.imag ... ... @@ -481,7 +493,7 @@ class Field(object): if copy: self._val[()] = new_val.val else: self._val = new_val.val self._val = np.array(new_val.val,dtype=self.dtype,copy=False) elif (np.isscalar(new_val)): self._val[()]=new_val elif isinstance(new_val, np.ndarray): ... ... @@ -490,7 +502,7 @@ class Field(object): else: if self.shape!=new_val.shape: raise ValueError("Shape mismatch") self._val = new_val self._val = np.array(new_val,dtype=self.dtype,copy=False) else: raise TypeError("unknown source type") return self ... ... @@ -573,12 +585,7 @@ class Field(object): shape """ dim_tuple = tuple(sp.dim for sp in self.domain) try: return int(reduce(lambda x, y: x * y, dim_tuple)) except TypeError: return 0 return self._val.size @property def dof(self): ... ...
 from ...operators import EndomorphicOperator,\ InvertibleOperatorMixin from ...energies.memoization import memo from ...basic_arithmetics import clipped_exp from ...basic_arithmetics import exp from ...sugar import create_composed_fft_operator ... ... @@ -71,7 +71,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin, @property @memo def _expp_sspace(self): return clipped_exp(self._fft(self.position)) return exp(self._fft(self.position)) @property @memo ... ...
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!