diff --git a/demos/getting_started_2.py b/demos/getting_started_2.py index f2267bcff4ed6adbb946f0541d7d370dca509282..298aff1576ce34c07e4d5bb822642e57f85a1b3a 100644 --- a/demos/getting_started_2.py +++ b/demos/getting_started_2.py @@ -20,29 +20,6 @@ import nifty5 as ift import numpy as np -class EnergyAdapter(ift.Energy): - def __init__(self, position, op): - super(EnergyAdapter, self).__init__(position) - self._op = op - pvar = ift.Linearization.make_var(position) - self._res = op(pvar) - - def at(self, position): - return EnergyAdapter(position, self._op) - - @property - def value(self): - return self._res.val.local_data[()] - - @property - def gradient(self): - return self._res.gradient - - @property - def metric(self): - return self._res.metric - - def get_2D_exposure(): x_shape, y_shape = position_space.shape @@ -120,7 +97,7 @@ if __name__ == '__main__': # Minimize the Hamiltonian H = ift.Hamiltonian(likelihood) - H = EnergyAdapter(position, H) + H = ift.EnergyAdapter(position, H) #ift.extra.check_value_gradient_consistency(H) H = H.make_invertible(ic_cg) H, convergence = minimizer(H) diff --git a/demos/getting_started_3b.py b/demos/getting_started_3b.py index 2a7ce7d27f3c29d5f61fad49635fd9bac2b7c33d..6cd4054c51372395772316a36d041410fd91a578 100644 --- a/demos/getting_started_3b.py +++ b/demos/getting_started_3b.py @@ -25,28 +25,6 @@ def get_random_LOS(n_los): ends = list(np.random.uniform(0, 1, (n_los, 2)).T) return starts, ends -class EnergyAdapter(ift.Energy): - def __init__(self, position, op): - super(EnergyAdapter, self).__init__(position) - self._op = op - pvar = ift.Linearization.make_var(position) - self._res = op(pvar) - - def at(self, position): - return EnergyAdapter(position, self._op) - - @property - def value(self): - return self._res.val.local_data[()] - - @property - def gradient(self): - return self._res.gradient - - @property - def metric(self): - return self._res.metric - if __name__ == '__main__': # FIXME description of the tutorial np.random.seed(42) @@ -114,7 +92,7 @@ if __name__ == '__main__': for _ in range(N_samples)] KL = ift.SampledKullbachLeiblerDivergence(H, samples) - KL = EnergyAdapter(position, KL) + KL = ift.EnergyAdapter(position, KL) KL = KL.make_invertible(ic_cg) KL, convergence = minimizer(KL) position = KL.position diff --git a/nifty5/__init__.py b/nifty5/__init__.py index 66bc0179b0f01fd718d6981073871322c26ad0b8..705f12439978b98e5c08c14d61c9c11c449d18a5 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -103,6 +103,7 @@ from .multi.block_diagonal_operator import BlockDiagonalOperator from .energies.kl import SampledKullbachLeiblerDivergence from .energies.hamiltonian import Hamiltonian +from .energies.energy_adapter import EnergyAdapter from .operator import Operator from .linearization import Linearization diff --git a/nifty5/energies/energy_adapter.py b/nifty5/energies/energy_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..61256371673ff1078e9f93b035fe34564327b17f --- /dev/null +++ b/nifty5/energies/energy_adapter.py @@ -0,0 +1,28 @@ +from __future__ import absolute_import, division, print_function + +from ..compat import * +from ..minimization.energy import Energy +from ..linearization import Linearization + + +class EnergyAdapter(Energy): + def __init__(self, position, op): + super(EnergyAdapter, self).__init__(position) + self._op = op + pvar = Linearization.make_var(position) + self._res = op(pvar) + + def at(self, position): + return EnergyAdapter(position, self._op) + + @property + def value(self): + return self._res.val.local_data[()] + + @property + def gradient(self): + return self._res.gradient + + @property + def metric(self): + return self._res.metric