Commit 0417187d authored by Martin Reinecke's avatar Martin Reinecke

port InverseGammaModel

parent 13d16030
......@@ -76,6 +76,7 @@ from .sugar import *
from .plotting.plot import plot, plot_finish
from .library.amplitude_model import AmplitudeModel
from .library.inverse_gamma_model import InverseGammaModel
from .library.los_response import LOSResponse
#from .library.inverse_gamma_model import InverseGammaModel
......
......@@ -22,63 +22,52 @@ import numpy as np
from scipy.stats import invgamma, norm
from ..compat import *
from ..operators.operator import Operator
from ..linearization import Linearization
from ..field import Field
from ..models.model import Model
from ..multi_field import MultiField
from ..operators.selection_operator import SelectionOperator
from ..sugar import makeOp
from ..utilities import memo
class InverseGammaModel(Model):
def __init__(self, position, alpha, q, key):
super(InverseGammaModel, self).__init__(position)
class InverseGammaModel(Operator):
def __init__(self, domain, alpha, q):
self._domain = domain
self._alpha = alpha
self._q = q
self._key = key
@classmethod
def make(cls, actual_position, alpha, q, key):
pos = cls.inverseIG(actual_position, alpha, q)
mf = MultiField.from_dict({key: pos})
return cls(mf, alpha, q, key)
def at(self, position):
return self.__class__(position, self._alpha, self._q, self._key)
@property
@memo
def value(self):
points = self.position[self._key].local_data
# MR FIXME?!
points = np.clip(points, None, 8.2)
points = Field.from_local_data(self.position[self._key].domain, points)
return self.IG(points, self._alpha, self._q)
def domain(self):
return self._domain
@property
@memo
def jacobian(self):
u = self.position[self._key].local_data
inner = norm.pdf(u)
outer_inv = invgamma.pdf(invgamma.ppf(norm.cdf(u),
def target(self):
return self._domain
def apply(self, x):
lin = isinstance(x, Linearization)
val = x.val.local_data if lin else x.local_data
# MR FIXME?!
points = np.clip(val, None, 8.2)
points = self.IG(points)
points = Field.from_local_data(self._domain, points)
if not lin:
return points
inner = norm.pdf(val)
outer_inv = invgamma.pdf(invgamma.ppf(norm.cdf(val),
self._alpha,
scale=self._q),
self._alpha, scale=self._q)
# FIXME
outer_inv = np.clip(outer_inv, 1e-20, None)
outer = 1/outer_inv
grad = Field.from_local_data(self.position[self._key].domain,
inner*outer)
grad = makeOp(MultiField.from_dict({self._key: grad},
self.position._domain))
return SelectionOperator(grad.target, self._key)*grad
jac = makeOp(Field.from_local_data(self._domain, inner*outer))
jac = jac(x.jac)
return Linearization(points, jac)
@staticmethod
def IG(field, alpha, q):
foo = invgamma.ppf(norm.cdf(field.local_data), alpha, scale=q)
return Field.from_local_data(field.domain, foo)
def IG(self, field):
return invgamma.ppf(norm.cdf(field), self._alpha, scale=self._q)
@staticmethod
def inverseIG(u, alpha, q):
res = norm.ppf(invgamma.cdf(u.local_data, alpha, scale=q))
return Field.from_local_data(u.domain, res)
# MR FIXME: Do we need this?
# def inverseIG(self, u):
# return Field.from_local_data(
# u.domain, norm.ppf(invgamma.cdf(u.local_data, self._alpha,
# scale=self._q)))
......@@ -113,22 +113,20 @@ class Model_Tests(unittest.TestCase):
pos = S.draw_sample()
ift.extra.check_value_gradient_consistency(model2, pos)
# @expand(product(
# [ift.GLSpace(15),
# ift.RGSpace(64, distances=.789),
# ift.RGSpace([32, 32], distances=.789)],
# [4, 78, 23]))
# def testPointModel(seld, space, seed):
#
# S = ift.ScalingOperator(1., space)
# pos = ift.MultiField.from_dict(
# {'points': S.draw_sample()})
# alpha = 1.5
# q = 0.73
# model = ift.InverseGammaModel(pos, alpha, q, 'points')
# # FIXME All those cdfs and ppfs are not very accurate
# ift.extra.check_value_gradient_consistency(model, tol=1e-5)
#
@expand(product(
[ift.GLSpace(15),
ift.RGSpace(64, distances=.789),
ift.RGSpace([32, 32], distances=.789)],
[4, 78, 23]))
def testPointModel(self, space, seed):
S = ift.ScalingOperator(1., space)
pos = S.draw_sample()
alpha = 1.5
q = 0.73
model = ift.InverseGammaModel(space, alpha, q)
# FIXME All those cdfs and ppfs are not very accurate
ift.extra.check_value_gradient_consistency(model, pos, tol=1e-2)
# @expand(product(
# ['Variable', 'Constant'],
# [ift.GLSpace(15),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment