Commit f227a461 authored by Reimar H Leike's avatar Reimar H Leike
Browse files

refactored InverseGammaOperator be a function call to a more general InterpolatingOperator

parent c14eb740
......@@ -39,6 +39,7 @@ from .operators.sandwich_operator import SandwichOperator
from .operators.scaling_operator import ScalingOperator
from .operators.block_diagonal_operator import BlockDiagonalOperator
from .operators.outer_product_operator import OuterProduct
from .operators.interpolating_operator import InterpolatingOperator
from .operators.simple_linear_operators import (
VdotOperator, ConjugationOperator, Realizer,
FieldAdapter, ducktape, GeometryRemover, NullOperator,
......
......@@ -18,14 +18,10 @@
import numpy as np
from scipy.stats import invgamma, norm
from ..domain_tuple import DomainTuple
from ..field import Field
from ..linearization import Linearization
from ..operators.operator import Operator
from ..sugar import makeOp
from ..operators.interpolating_operator import InterpolatingOperator
class InverseGammaOperator(Operator):
def InverseGammaOperator(domain, alpha, q, delta=0.001):
"""Transforms a Gaussian into an inverse gamma distribution.
The pdf of the inverse gamma distribution is defined as follows:
......@@ -53,54 +49,5 @@ class InverseGammaOperator(Operator):
delta : float
distance between sampling points for linear interpolation.
"""
def __init__(self, domain, alpha, q, delta=0.001):
self._domain = self._target = DomainTuple.make(domain)
self._alpha, self._q, self._delta = \
float(alpha), float(q), float(delta)
self._xmin, self._xmax = -8.2, 8.2
# Precompute
xs = np.arange(self._xmin, self._xmax+2*delta, delta)
self._table = np.log(invgamma.ppf(norm.cdf(xs), self._alpha,
scale=self._q))
self._deriv = (self._table[1:]-self._table[:-1]) / delta
def apply(self, x):
self._check_input(x)
lin = isinstance(x, Linearization)
val = x.val.local_data if lin else x.local_data
val = (np.clip(val, self._xmin, self._xmax) - self._xmin) / self._delta
# Operator
fi = np.floor(val).astype(int)
w = val - fi
res = np.exp((1 - w)*self._table[fi] + w*self._table[fi + 1])
points = Field.from_local_data(self._domain, res)
if not lin:
return points
# Derivative of linear interpolation
der = self._deriv[fi]*res
jac = makeOp(Field.from_local_data(self._domain, der))
jac = jac(x.jac)
return x.new(points, jac)
@staticmethod
def IG(field, alpha, q):
foo = invgamma.ppf(norm.cdf(field.local_data), alpha, scale=q)
return Field.from_local_data(field.domain, foo)
@staticmethod
def inverseIG(u, alpha, q):
res = norm.ppf(invgamma.cdf(u.local_data, alpha, scale=q))
return Field.from_local_data(u.domain, res)
@property
def alpha(self):
return self._alpha
@property
def q(self):
return self._q
func = lambda x: np.log(invgamma.ppf(norm.cdf(x), alpha))
return InterpolatingOperator(domain, func, (-8.2,8.2), delta)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..domain_tuple import DomainTuple
from ..field import Field
from ..linearization import Linearization
from ..operators.operator import Operator
from ..sugar import makeOp
class InterpolatingOperator(Operator):
"""Represents an arbitrary local nonlinearity by linear interpolation
Given any function, this operator computes a first order interpolation
on a grid. This can speed up computation for time intense local nonlinearities.
Parameters
----------
domain : Domain, tuple of Domain or DomainTuple
The domain on which the operator shall be defined. This is at the same
time the domain and the target of the operator.
f : function
The function to interpolate
bounds : Tuple of two real number
The begin and end of the interval on which the function gets interpolated
delta : float, optional
distance between sampling points for linear interpolation.
if no delta is given, 1000 points are assumed
"""
def __init__(self, domain, f, bounds, delta=None):
self._domain = self._target = DomainTuple.make(domain)
self._delta = float(delta)
self._xmin, self._xmax = bounds
# Precompute
xs = np.arange(self._xmin, self._xmax+2*delta, delta)
try:
self._table = f(xs)
except:
self._table = np.zeros_like(xs)
for i,x in enumerate(xs):
self._table[i] = f(x)
self._deriv = (self._table[1:]-self._table[:-1]) / delta
def apply(self, x):
self._check_input(x)
lin = isinstance(x, Linearization)
val = x.val.local_data if lin else x.local_data
val = (np.clip(val, self._xmin, self._xmax) - self._xmin) / self._delta
# Operator
fi = np.floor(val).astype(int)
w = val - fi
res = np.exp((1 - w)*self._table[fi] + w*self._table[fi + 1])
points = Field.from_local_data(self._domain, res)
if not lin:
return points
# Derivative of linear interpolation
der = self._deriv[fi]*res
jac = makeOp(Field.from_local_data(self._domain, der))
jac = jac(x.jac)
return x.new(points, jac)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment