Commit d7b80547 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

move EnergyAdapter

parent 007fb8dc
...@@ -20,29 +20,6 @@ import nifty5 as ift ...@@ -20,29 +20,6 @@ import nifty5 as ift
import numpy as np import numpy as np
class EnergyAdapter(ift.Energy):
def __init__(self, position, op):
super(EnergyAdapter, self).__init__(position)
self._op = op
pvar = ift.Linearization.make_var(position)
self._res = op(pvar)
def at(self, position):
return EnergyAdapter(position, self._op)
@property
def value(self):
return self._res.val.local_data[()]
@property
def gradient(self):
return self._res.gradient
@property
def metric(self):
return self._res.metric
def get_2D_exposure(): def get_2D_exposure():
x_shape, y_shape = position_space.shape x_shape, y_shape = position_space.shape
...@@ -120,7 +97,7 @@ if __name__ == '__main__': ...@@ -120,7 +97,7 @@ if __name__ == '__main__':
# Minimize the Hamiltonian # Minimize the Hamiltonian
H = ift.Hamiltonian(likelihood) H = ift.Hamiltonian(likelihood)
H = EnergyAdapter(position, H) H = ift.EnergyAdapter(position, H)
#ift.extra.check_value_gradient_consistency(H) #ift.extra.check_value_gradient_consistency(H)
H = H.make_invertible(ic_cg) H = H.make_invertible(ic_cg)
H, convergence = minimizer(H) H, convergence = minimizer(H)
......
...@@ -25,28 +25,6 @@ def get_random_LOS(n_los): ...@@ -25,28 +25,6 @@ def get_random_LOS(n_los):
ends = list(np.random.uniform(0, 1, (n_los, 2)).T) ends = list(np.random.uniform(0, 1, (n_los, 2)).T)
return starts, ends return starts, ends
class EnergyAdapter(ift.Energy):
def __init__(self, position, op):
super(EnergyAdapter, self).__init__(position)
self._op = op
pvar = ift.Linearization.make_var(position)
self._res = op(pvar)
def at(self, position):
return EnergyAdapter(position, self._op)
@property
def value(self):
return self._res.val.local_data[()]
@property
def gradient(self):
return self._res.gradient
@property
def metric(self):
return self._res.metric
if __name__ == '__main__': if __name__ == '__main__':
# FIXME description of the tutorial # FIXME description of the tutorial
np.random.seed(42) np.random.seed(42)
...@@ -114,7 +92,7 @@ if __name__ == '__main__': ...@@ -114,7 +92,7 @@ if __name__ == '__main__':
for _ in range(N_samples)] for _ in range(N_samples)]
KL = ift.SampledKullbachLeiblerDivergence(H, samples) KL = ift.SampledKullbachLeiblerDivergence(H, samples)
KL = EnergyAdapter(position, KL) KL = ift.EnergyAdapter(position, KL)
KL = KL.make_invertible(ic_cg) KL = KL.make_invertible(ic_cg)
KL, convergence = minimizer(KL) KL, convergence = minimizer(KL)
position = KL.position position = KL.position
......
...@@ -103,6 +103,7 @@ from .multi.block_diagonal_operator import BlockDiagonalOperator ...@@ -103,6 +103,7 @@ from .multi.block_diagonal_operator import BlockDiagonalOperator
from .energies.kl import SampledKullbachLeiblerDivergence from .energies.kl import SampledKullbachLeiblerDivergence
from .energies.hamiltonian import Hamiltonian from .energies.hamiltonian import Hamiltonian
from .energies.energy_adapter import EnergyAdapter
from .operator import Operator from .operator import Operator
from .linearization import Linearization from .linearization import Linearization
......
from __future__ import absolute_import, division, print_function
from ..compat import *
from ..minimization.energy import Energy
from ..linearization import Linearization
class EnergyAdapter(Energy):
def __init__(self, position, op):
super(EnergyAdapter, self).__init__(position)
self._op = op
pvar = Linearization.make_var(position)
self._res = op(pvar)
def at(self, position):
return EnergyAdapter(position, self._op)
@property
def value(self):
return self._res.val.local_data[()]
@property
def gradient(self):
return self._res.gradient
@property
def metric(self):
return self._res.metric
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment