Commit 9f070958 authored by Martin Reinecke's avatar Martin Reinecke

merge

parents ee45aff1 d7aaaea0
...@@ -99,6 +99,6 @@ if __name__ == '__main__': ...@@ -99,6 +99,6 @@ if __name__ == '__main__':
H, convergence = minimizer(H) H, convergence = minimizer(H)
# Plot results # Plot results
ift.plot(sky(H.position)) ift.plot(sky(H.position), title='Reconstruction')
ift.plot_finish() ift.plot(GR.adjoint(data), title='Data')
# FIXME PLOTTING ift.plot_finish(name='getting_started_2.png', xsize=16, ysize=16)
...@@ -77,7 +77,9 @@ from .plotting.plot import plot, plot_finish ...@@ -77,7 +77,9 @@ from .plotting.plot import plot, plot_finish
from .library.amplitude_model import AmplitudeModel from .library.amplitude_model import AmplitudeModel
from .library.los_response import LOSResponse from .library.los_response import LOSResponse
# from .library.point_sources import PointSources
#from .library.inverse_gamma_model import InverseGammaModel
from .library.wiener_filter_curvature import WienerFilterCurvature from .library.wiener_filter_curvature import WienerFilterCurvature
from .library.correlated_fields import CorrelatedField from .library.correlated_fields import CorrelatedField
# make_mf_correlated_field) # make_mf_correlated_field)
......
...@@ -30,28 +30,35 @@ from ..sugar import makeOp ...@@ -30,28 +30,35 @@ from ..sugar import makeOp
from ..utilities import memo from ..utilities import memo
class PointSources(Model): class InverseGammaModel(Model):
def __init__(self, position, alpha, q): def __init__(self, position, alpha, q, key):
super(PointSources, self).__init__(position) super(InverseGammaModel, self).__init__(position)
self._alpha = alpha self._alpha = alpha
self._q = q self._q = q
self._key = key
@classmethod
def make(cls, actual_position, alpha, q, key):
pos = cls.inverseIG(actual_position, alpha, q)
mf = MultiField.from_dict({key: pos})
return cls(mf, alpha, q, key)
def at(self, position): def at(self, position):
return self.__class__(position, self._alpha, self._q) return self.__class__(position, self._alpha, self._q, self._key)
@property @property
@memo @memo
def value(self): def value(self):
points = self.position['points'].local_data points = self.position[self._key].local_data
# MR FIXME?! # MR FIXME?!
points = np.clip(points, None, 8.2) points = np.clip(points, None, 8.2)
points = Field.from_local_data(self.position['points'].domain, points) points = Field.from_local_data(self.position[self._key].domain, points)
return self.IG(points, self._alpha, self._q) return self.IG(points, self._alpha, self._q)
@property @property
@memo @memo
def jacobian(self): def jacobian(self):
u = self.position['points'].local_data u = self.position[self._key].local_data
inner = norm.pdf(u) inner = norm.pdf(u)
outer_inv = invgamma.pdf(invgamma.ppf(norm.cdf(u), outer_inv = invgamma.pdf(invgamma.ppf(norm.cdf(u),
self._alpha, self._alpha,
...@@ -60,19 +67,18 @@ class PointSources(Model): ...@@ -60,19 +67,18 @@ class PointSources(Model):
# FIXME # FIXME
outer_inv = np.clip(outer_inv, 1e-20, None) outer_inv = np.clip(outer_inv, 1e-20, None)
outer = 1/outer_inv outer = 1/outer_inv
grad = Field.from_local_data(self.position['points'].domain, grad = Field.from_local_data(self.position[self._key].domain,
inner*outer) inner*outer)
grad = makeOp(MultiField.from_dict({"points": grad}, grad = makeOp(MultiField.from_dict({self._key: grad},
self.position._domain)) self.position._domain))
return SelectionOperator(grad.target, 'points')*grad return SelectionOperator(grad.target, self._key)*grad
@staticmethod @staticmethod
def IG(field, alpha, q): def IG(field, alpha, q):
foo = invgamma.ppf(norm.cdf(field.local_data), alpha, scale=q) foo = invgamma.ppf(norm.cdf(field.local_data), alpha, scale=q)
return Field.from_local_data(field.domain, foo) return Field.from_local_data(field.domain, foo)
# MR FIXME: why does this take an np.ndarray instead of a Field?
@staticmethod @staticmethod
def inverseIG(u, alpha, q): def inverseIG(u, alpha, q):
res = norm.ppf(invgamma.cdf(u, alpha, scale=q)) res = norm.ppf(invgamma.cdf(u.local_data, alpha, scale=q))
return res return Field.from_local_data(u.domain, res)
...@@ -125,8 +125,8 @@ class Model_Tests(unittest.TestCase): ...@@ -125,8 +125,8 @@ class Model_Tests(unittest.TestCase):
# {'points': S.draw_sample()}) # {'points': S.draw_sample()})
# alpha = 1.5 # alpha = 1.5
# q = 0.73 # q = 0.73
# model = ift.PointSources(pos, alpha, q) # model = ift.InverseGammaModel(pos, alpha, q, 'points')
# # FIXME All those cdfs and ppfs are not that accurate # # FIXME All those cdfs and ppfs are not very accurate
# ift.extra.check_value_gradient_consistency(model, tol=1e-5) # ift.extra.check_value_gradient_consistency(model, tol=1e-5)
# #
# @expand(product( # @expand(product(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment