From 6bbeea6f1961c2da237be7d6d88c9a7a91214907 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Tue, 21 Aug 2018 17:29:47 +0200 Subject: [PATCH] fixes --- demos/bernoulli_demo.py | 2 +- demos/getting_started_2.py | 2 +- demos/polynomial_fit.py | 4 ++-- nifty5/minimization/energy_adapter.py | 4 +++- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/demos/bernoulli_demo.py b/demos/bernoulli_demo.py index f81d4ad7f..a6fd2292e 100644 --- a/demos/bernoulli_demo.py +++ b/demos/bernoulli_demo.py @@ -73,7 +73,7 @@ if __name__ == '__main__': # Minimize the Hamiltonian H = ift.Hamiltonian(likelihood, ic_sampling) - H = ift.EnergyAdapter(position, H) + H = ift.EnergyAdapter(position, H, want_metric=True) # minimizer = ift.L_BFGS(ic_newton) H, convergence = minimizer(H) diff --git a/demos/getting_started_2.py b/demos/getting_started_2.py index ad0ff69ea..065616da9 100644 --- a/demos/getting_started_2.py +++ b/demos/getting_started_2.py @@ -93,7 +93,7 @@ if __name__ == '__main__': # Minimize the Hamiltonian H = ift.Hamiltonian(likelihood) - H = ift.EnergyAdapter(position, H) + H = ift.EnergyAdapter(position, H, want_metric=True) H, convergence = minimizer(H) # Plot results diff --git a/demos/polynomial_fit.py b/demos/polynomial_fit.py index 8dab1814a..403269c79 100644 --- a/demos/polynomial_fit.py +++ b/demos/polynomial_fit.py @@ -87,14 +87,14 @@ N = ift.DiagonalOperator(ift.from_global_data(d_space, var)) IC = ift.GradientNormController(tol_abs_gradnorm=1e-8) likelihood = ift.GaussianEnergy(d, N)(R) Ham = ift.Hamiltonian(likelihood, IC) -H = ift.EnergyAdapter(params, Ham) +H = ift.EnergyAdapter(params, Ham, want_metric=True) # Minimize minimizer = ift.NewtonCG(IC) H, _ = minimizer(H) # Draw posterior samples -metric = Ham(ift.Linearization.make_var(H.position)).metric +metric = Ham(ift.Linearization.make_var(H.position, want_metric=True)).metric samples = [metric.draw_sample(from_inverse=True) + H.position for _ in range(N_samples)] diff --git a/nifty5/minimization/energy_adapter.py b/nifty5/minimization/energy_adapter.py index d0e71a96f..b4627602a 100644 --- a/nifty5/minimization/energy_adapter.py +++ b/nifty5/minimization/energy_adapter.py @@ -12,6 +12,7 @@ class EnergyAdapter(Energy): super(EnergyAdapter, self).__init__(position) self._op = op self._constants = constants + self._want_metric = want_metric if len(self._constants) == 0: tmp = self._op(Linearization.make_var(self._position, want_metric)) else: @@ -25,7 +26,8 @@ class EnergyAdapter(Energy): self._metric = tmp._metric def at(self, position): - return EnergyAdapter(position, self._op, self._constants) + return EnergyAdapter(position, self._op, self._constants, + self._want_metric) @property def value(self): -- GitLab