Commit 380b5273 authored by Philipp Arras's avatar Philipp Arras
Browse files

Restructuring

parent e627438f
Pipeline #70652 passed with stages
in 17 minutes and 55 seconds
......@@ -20,12 +20,18 @@ from scipy.stats import invgamma, norm
from .. import Adder
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..linearization import Linearization
from ..operators.operator import Operator
from ..sugar import makeOp
def _f_on_np(f, arr):
fld = Field.from_raw(UnstructuredDomain(arr.shape), arr)
return f(fld).val
class _InterpolationOperator(Operator):
"""
Calculates a function pointwise on a field by interpolation.
......@@ -38,44 +44,34 @@ class _InterpolationOperator(Operator):
The domain on which the field shall be defined. This is at the same
time the domain and the target of the operator.
func : function
The function which is applied on the field.
The function which is applied on the field. Assumed to act on numpy
arrays.
xmin : float
The smallest value for which func will be evaluated.
xmax : float
The largest value for which func will be evaluated.
delta : float
Distance between sampling points for linear interpolation.
table_func : {'None', 'exp', 'log', 'power'}
exponent : float
This is only used by the 'power' table_func.
table_func : function
Non-linear function applied to table in order to transform the table
to a more linear space. Assumed to act on `Linearization`s, optional.
inv_table_func : function
Inverse of table_func, optional.
"""
def __init__(self, domain, func, xmin, xmax, delta, table_func=None, exponent=None):
def __init__(self, domain, func, xmin, xmax, delta, table_func=None, inv_table_func=None):
self._domain = self._target = DomainTuple.make(domain)
self._xmin, self._xmax = float(xmin), float(xmax)
self._d = float(delta)
self._xs = np.arange(xmin, xmax+2*self._d, self._d)
self._table = func(self._xs)
self._transform = table_func is not None
self._args = []
if exponent is not None and table_func != 'power':
raise Exception("exponent is only used when table_func is 'power'.")
if table_func is None:
pass
elif table_func == 'exp':
self._table = np.exp(self._table)
self._inv_table_func = 'log'
elif table_func == 'log':
self._table = np.log(self._table)
self._inv_table_func = 'exp'
elif table_func == 'power':
if not np.isscalar(exponent):
return NotImplemented
self._table = np.power(self._table, exponent)
self._inv_table_func = '__pow__'
self._args = [np.power(float(exponent), -1)]
else:
return NotImplemented
if table_func is not None:
if inv_table_func is None:
raise ValueError
a = func(np.random.randn(10))
a1 = _f_on_np(lambda x: inv_table_func(table_func(x)), a)
np.testing.assert_allclose(a, a1)
self._table = _f_on_np(table_func, self._table)
self._inv_table_func = inv_table_func
self._deriv = (self._table[1:] - self._table[:-1]) / self._d
def apply(self, x):
......@@ -86,16 +82,12 @@ class _InterpolationOperator(Operator):
fi = np.floor(val).astype(int)
w = val - fi
res = (1-w)*self._table[fi] + w*self._table[fi+1]
resfld = Field(self._domain, res)
if not lin:
if self._transform:
resfld = getattr(resfld, self._inv_table_func)(*self._args)
return resfld
lin = Linearization.make_var(resfld)
if self._transform:
lin = getattr(lin, self._inv_table_func)(*self._args)
jac = makeOp(Field(self._domain, self._deriv[fi]))
return x.new(lin.val, lin.jac @ jac)
res = Field(self._domain, res)
if lin:
res = x.new(res, makeOp(Field(self._domain, self._deriv[fi])))
if self._inv_table_func is not None:
res = self._inv_table_func(res)
return res
def InverseGammaOperator(domain, alpha, q, delta=0.001):
......@@ -127,8 +119,8 @@ def InverseGammaOperator(domain, alpha, q, delta=0.001):
delta : float
Distance between sampling points for linear interpolation.
"""
func = lambda x: invgamma.ppf(norm.cdf(x), float(alpha))
op = _InterpolationOperator(domain, func, -8.2, 8.2, delta, 'log')
op = _InterpolationOperator(domain, lambda x: invgamma.ppf(norm.cdf(x), float(alpha)),
-8.2, 8.2, delta, lambda x: x.log(), lambda x: x.exp())
if np.isscalar(q):
return op.scale(q)
return makeOp(q) @ op
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment