Commit f430d35b authored by Philipp Arras's avatar Philipp Arras
Browse files

Restructure InverseGammaOperator as Interpolation Operator

parent 128309f6
Pipeline #70449 passed with stages
in 15 minutes and 34 seconds
......@@ -69,7 +69,7 @@ from .minimization.metric_gaussian_kl import MetricGaussianKL
from .sugar import *
from .plot import Plot
from .library.inverse_gamma_operator import InverseGammaOperator
from .library.special_distributions import InverseGammaOperator
from .library.los_response import LOSResponse
from .library.dynamic_operator import (dynamic_operator,
dynamic_lightcone_operator)
......
......@@ -25,7 +25,35 @@ from ..operators.operator import Operator
from ..sugar import makeOp
class InverseGammaOperator(Operator):
class _InterpolationOperator(Operator):
def __init__(self, domain, func, xmin, xmax, delta, table_func=None, inverse_table_func=None):
self._domain = self._target = DomainTuple.make(domain)
self._xmin, self._xmax = float(xmin), float(xmax)
self._d = float(delta)
self._xs = np.arange(xmin, xmax+2*self._d, self._d)
if table_func is not None:
foo = func(np.random.randn(10))
np.testing.assert_allclose(foo, inverse_table_func(table_func(foo)))
self._table = table_func(func(self._xs))
self._deriv = (self._table[1:]-self._table[:-1]) / self._d
self._inv_table_func = inverse_table_func
def apply(self, x):
self._check_input(x)
lin = isinstance(x, Linearization)
val = x.val.val if lin else x.val
val = (np.clip(val, self._xmin, self._xmax) - self._xmin) / self._d
fi = np.floor(val).astype(int)
w = val - fi
res = self._inv_table_func((1-w)*self._table[fi] + w*self._table[fi+1])
resfld = Field(self._domain, res)
if not lin:
return resfld
jac = makeOp(Field(self._domain, self._deriv[fi]*res)) @ x.jac
return x.new(resfld, jac)
def InverseGammaOperator(domain, alpha, q, delta=0.001):
"""Transforms a Gaussian into an inverse gamma distribution.
The pdf of the inverse gamma distribution is defined as follows:
......@@ -53,54 +81,6 @@ class InverseGammaOperator(Operator):
delta : float
distance between sampling points for linear interpolation.
"""
def __init__(self, domain, alpha, q, delta=0.001):
self._domain = self._target = DomainTuple.make(domain)
self._alpha, self._q, self._delta = \
float(alpha), float(q), float(delta)
self._xmin, self._xmax = -8.2, 8.2
# Precompute
xs = np.arange(self._xmin, self._xmax+2*delta, delta)
self._table = np.log(invgamma.ppf(norm.cdf(xs), self._alpha,
scale=self._q))
self._deriv = (self._table[1:]-self._table[:-1]) / delta
def apply(self, x):
self._check_input(x)
lin = isinstance(x, Linearization)
val = x.val.val if lin else x.val
val = (np.clip(val, self._xmin, self._xmax) - self._xmin) / self._delta
# Operator
fi = np.floor(val).astype(int)
w = val - fi
res = np.exp((1 - w)*self._table[fi] + w*self._table[fi + 1])
points = Field(self._domain, res)
if not lin:
return points
# Derivative of linear interpolation
der = self._deriv[fi]*res
jac = makeOp(Field(self._domain, der))
jac = jac(x.jac)
return x.new(points, jac)
@staticmethod
def IG(field, alpha, q):
foo = invgamma.ppf(norm.cdf(field.val), alpha, scale=q)
return Field(field.domain, foo)
@staticmethod
def inverseIG(u, alpha, q):
res = norm.ppf(invgamma.cdf(u.val, alpha, scale=q))
return Field(u.domain, res)
@property
def alpha(self):
return self._alpha
@property
def q(self):
return self._q
return _InterpolationOperator(domain,
lambda x: invgamma.ppf(norm.cdf(x), float(alpha)),
-8.2, 8.2, delta, np.log, np.exp).scale(q)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import pytest
from numpy.testing import assert_allclose
from scipy.stats import invgamma, norm
import nifty6 as ift
from ..common import list2fixture
pmp = pytest.mark.parametrize
pmp = pytest.mark.parametrize
space = list2fixture([ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)])
seed = list2fixture([4, 78, 23])
def testInverseGammaAccuracy(space, seed):
S = ift.ScalingOperator(space, 1.)
pos = S.draw_sample()
alpha = 1.5
q = 0.73
op = ift.InverseGammaOperator(space, alpha, q)
arr1 = op(pos).val
arr0 = q*invgamma.ppf(norm.cdf(pos.val), alpha)
assert_allclose(arr0, arr1)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment