diff --git a/demos/getting_started_2.py b/demos/getting_started_2.py index c6622395e03737738e24a4702468f62ac6e0c9e8..bcceb54cf4af34080b995a3f431e147c6da2726b 100644 --- a/demos/getting_started_2.py +++ b/demos/getting_started_2.py @@ -99,6 +99,6 @@ if __name__ == '__main__': H, convergence = minimizer(H) # Plot results - ift.plot(sky(H.position)) - ift.plot_finish() - # FIXME PLOTTING + ift.plot(sky(H.position), title='Reconstruction') + ift.plot(GR.adjoint(data), title='Data') + ift.plot_finish(name='getting_started_2.png', xsize=16, ysize=16) diff --git a/nifty5/__init__.py b/nifty5/__init__.py index 40965a95e6f9e782311b2d8d4c8c2f66ba3458d4..d978f2b683a8eb1baabda13f3af55c32724569f4 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -77,7 +77,9 @@ from .plotting.plot import plot, plot_finish from .library.amplitude_model import AmplitudeModel from .library.los_response import LOSResponse -# from .library.point_sources import PointSources + +#from .library.inverse_gamma_model import InverseGammaModel + from .library.wiener_filter_curvature import WienerFilterCurvature from .library.correlated_fields import CorrelatedField # make_mf_correlated_field) diff --git a/nifty5/library/point_sources.py b/nifty5/library/inverse_gamma_model.py similarity index 69% rename from nifty5/library/point_sources.py rename to nifty5/library/inverse_gamma_model.py index 381ec3a9697479718a41b5978b908ee1c8e287e1..aed7d5c324d09a7ec009ef8777c99b41c266b864 100644 --- a/nifty5/library/point_sources.py +++ b/nifty5/library/inverse_gamma_model.py @@ -30,28 +30,35 @@ from ..sugar import makeOp from ..utilities import memo -class PointSources(Model): - def __init__(self, position, alpha, q): - super(PointSources, self).__init__(position) +class InverseGammaModel(Model): + def __init__(self, position, alpha, q, key): + super(InverseGammaModel, self).__init__(position) self._alpha = alpha self._q = q + self._key = key + + @classmethod + def make(cls, actual_position, alpha, q, key): + pos = cls.inverseIG(actual_position, alpha, q) + mf = MultiField.from_dict({key: pos}) + return cls(mf, alpha, q, key) def at(self, position): - return self.__class__(position, self._alpha, self._q) + return self.__class__(position, self._alpha, self._q, self._key) @property @memo def value(self): - points = self.position['points'].local_data + points = self.position[self._key].local_data # MR FIXME?! points = np.clip(points, None, 8.2) - points = Field.from_local_data(self.position['points'].domain, points) + points = Field.from_local_data(self.position[self._key].domain, points) return self.IG(points, self._alpha, self._q) @property @memo def jacobian(self): - u = self.position['points'].local_data + u = self.position[self._key].local_data inner = norm.pdf(u) outer_inv = invgamma.pdf(invgamma.ppf(norm.cdf(u), self._alpha, @@ -60,19 +67,18 @@ class PointSources(Model): # FIXME outer_inv = np.clip(outer_inv, 1e-20, None) outer = 1/outer_inv - grad = Field.from_local_data(self.position['points'].domain, + grad = Field.from_local_data(self.position[self._key].domain, inner*outer) - grad = makeOp(MultiField.from_dict({"points": grad}, + grad = makeOp(MultiField.from_dict({self._key: grad}, self.position._domain)) - return SelectionOperator(grad.target, 'points')*grad + return SelectionOperator(grad.target, self._key)*grad @staticmethod def IG(field, alpha, q): foo = invgamma.ppf(norm.cdf(field.local_data), alpha, scale=q) return Field.from_local_data(field.domain, foo) - # MR FIXME: why does this take an np.ndarray instead of a Field? @staticmethod def inverseIG(u, alpha, q): - res = norm.ppf(invgamma.cdf(u, alpha, scale=q)) - return res + res = norm.ppf(invgamma.cdf(u.local_data, alpha, scale=q)) + return Field.from_local_data(u.domain, res) diff --git a/test/test_models/test_model_gradients.py b/test/test_models/test_model_gradients.py index 4849857f939552a58098530f3e9fd56b952ad1fa..11c7dfd418dfde65f8339b4ac3ec1febac85e10f 100644 --- a/test/test_models/test_model_gradients.py +++ b/test/test_models/test_model_gradients.py @@ -125,8 +125,8 @@ class Model_Tests(unittest.TestCase): # {'points': S.draw_sample()}) # alpha = 1.5 # q = 0.73 -# model = ift.PointSources(pos, alpha, q) -# # FIXME All those cdfs and ppfs are not that accurate +# model = ift.InverseGammaModel(pos, alpha, q, 'points') +# # FIXME All those cdfs and ppfs are not very accurate # ift.extra.check_value_gradient_consistency(model, tol=1e-5) # # @expand(product(