From bb47136888bd9a056d6ad13e6c1ac03c94c73e9a Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Mon, 30 Jul 2018 16:41:29 +0200 Subject: [PATCH] PointSources -> InverseGammaModel --- nifty5/__init__.py | 2 +- ...oint_sources.py => inverse_gamma_model.py} | 21 ++++++++++--------- test/test_models/test_model_gradients.py | 4 ++-- 3 files changed, 14 insertions(+), 13 deletions(-) rename nifty5/library/{point_sources.py => inverse_gamma_model.py} (79%) diff --git a/nifty5/__init__.py b/nifty5/__init__.py index 6c69b44bf..129735eaf 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -82,7 +82,7 @@ from .plotting.plot import plot, plot_finish from .library.amplitude_model import make_amplitude_model from .library.gaussian_energy import GaussianEnergy from .library.los_response import LOSResponse -from .library.point_sources import PointSources +from .library.inverse_gamma_model import InverseGammaModel from .library.poissonian_energy import PoissonianEnergy from .library.wiener_filter_curvature import WienerFilterCurvature from .library.correlated_fields import (make_correlated_field, diff --git a/nifty5/library/point_sources.py b/nifty5/library/inverse_gamma_model.py similarity index 79% rename from nifty5/library/point_sources.py rename to nifty5/library/inverse_gamma_model.py index fa121f675..0cf4ce362 100644 --- a/nifty5/library/point_sources.py +++ b/nifty5/library/inverse_gamma_model.py @@ -30,28 +30,29 @@ from ..sugar import makeOp from ..utilities import memo -class PointSources(Model): - def __init__(self, position, alpha, q): - super(PointSources, self).__init__(position) +class InverseGammaModel(Model): + def __init__(self, position, alpha, q, key): + super(InverseGammaModel, self).__init__(position) self._alpha = alpha self._q = q + self._key = key def at(self, position): - return self.__class__(position, self._alpha, self._q) + return self.__class__(position, self._alpha, self._q, self._key) @property @memo def value(self): - points = self.position['points'].local_data + points = self.position[self._key].local_data # MR FIXME?! points = np.clip(points, None, 8.2) - points = Field.from_local_data(self.position['points'].domain, points) + points = Field.from_local_data(self.position[self._key].domain, points) return self.IG(points, self._alpha, self._q) @property @memo def jacobian(self): - u = self.position['points'].local_data + u = self.position[self._key].local_data inner = norm.pdf(u) outer_inv = invgamma.pdf(invgamma.ppf(norm.cdf(u), self._alpha, @@ -60,11 +61,11 @@ class PointSources(Model): # FIXME outer_inv = np.clip(outer_inv, 1e-20, None) outer = 1/outer_inv - grad = Field.from_local_data(self.position['points'].domain, + grad = Field.from_local_data(self.position[self._key].domain, inner*outer) - grad = makeOp(MultiField.from_dict({"points": grad}, + grad = makeOp(MultiField.from_dict({self._key: grad}, self.position._domain)) - return SelectionOperator(grad.target, 'points')*grad + return SelectionOperator(grad.target, self._key)*grad @staticmethod def IG(field, alpha, q): diff --git a/test/test_models/test_model_gradients.py b/test/test_models/test_model_gradients.py index 9cab81df1..8821f6bcb 100644 --- a/test/test_models/test_model_gradients.py +++ b/test/test_models/test_model_gradients.py @@ -151,8 +151,8 @@ class Model_Tests(unittest.TestCase): {'points': S.draw_sample()}) alpha = 1.5 q = 0.73 - model = ift.PointSources(pos, alpha, q) - # FIXME All those cdfs and ppfs are not that accurate + model = ift.InverseGammaModel(pos, alpha, q) + # FIXME All those cdfs and ppfs are not very accurate ift.extra.check_value_gradient_consistency(model, tol=1e-5) @expand(product( -- GitLab