Commit 610a2f7c authored by Martin Reinecke's avatar Martin Reinecke
Browse files

polishing

parent 279e18f7
...@@ -41,25 +41,37 @@ class GaussianEnergy2(ift.Operator): ...@@ -41,25 +41,37 @@ class GaussianEnergy2(ift.Operator):
return ift.Linearization(res.val, res.jac, metric) return ift.Linearization(res.val, res.jac, metric)
class PoissonianEnergy2(ift.Operator): class PoissonianEnergy2(ift.Operator):
def __init__(self, d): def __init__(self, op, d):
super(PoissonianEnergy2, self).__init__() super(PoissonianEnergy2, self).__init__()
self._op = op
self._d = d self._d = d
def __call__(self, x): def __call__(self, x):
x = self._op(x)
res = (x - self._d*mylog(x)).sum() res = (x - self._d*mylog(x)).sum()
metric = ift.SandwichOperator.make(x.jac, ift.makeOp(1./x.val)) metric = ift.SandwichOperator.make(x.jac, ift.makeOp(1./x.val))
return ift.Linearization(res.val, res.jac, metric) return ift.Linearization(res.val, res.jac, metric)
class MyHamiltonian(ift.Energy): class MyHamiltonian(ift.Operator):
def __init__(self, lh):
super(MyHamiltonian, self).__init__()
self._lh = lh
self._prior = GaussianEnergy2()
pvar = ift.Linearization.make_var(position)
self._res = self._lh(pvar)+self._prior(pvar)
def __call__(self, x):
return self._lh(x) + self._prior(x)
class EnergyAdapter(ift.Energy):
def __init__(self, position, op): def __init__(self, position, op):
super(MyHamiltonian, self).__init__(position) super(EnergyAdapter, self).__init__(position)
self._op = op self._op = op
prior = GaussianEnergy2()
pvar = ift.Linearization.make_var(position) pvar = ift.Linearization.make_var(position)
self._res = op(pvar)+prior(pvar) self._res = op(pvar)
def at(self, position): def at(self, position):
return MyHamiltonian(position, self._op) return EnergyAdapter(position, self._op)
@property @property
def value(self): def value(self):
...@@ -73,15 +85,6 @@ class MyHamiltonian(ift.Energy): ...@@ -73,15 +85,6 @@ class MyHamiltonian(ift.Energy):
def metric(self): def metric(self):
return self._res.metric return self._res.metric
class OperatorSequence(ift.Operator):
def __init__(self, ops):
super(OperatorSequence, self).__init__()
self._ops = ops
def __call__(self, x):
for op in reversed(self._ops):
x = op(x)
return x
def get_2D_exposure(): def get_2D_exposure():
x_shape, y_shape = position_space.shape x_shape, y_shape = position_space.shape
...@@ -149,14 +152,15 @@ if __name__ == '__main__': ...@@ -149,14 +152,15 @@ if __name__ == '__main__':
data = ift.Field.from_global_data(d_space, data) data = ift.Field.from_global_data(d_space, data)
# Compute likelihood and Hamiltonian # Compute likelihood and Hamiltonian
likelihood = PoissonianEnergy2(data) likelihood = PoissonianEnergy2(lamb, data)
ic_cg = ift.GradientNormController(iteration_limit=50) ic_cg = ift.GradientNormController(iteration_limit=50)
ic_newton = ift.GradientNormController(name='Newton', iteration_limit=50, ic_newton = ift.GradientNormController(name='Newton', iteration_limit=50,
tol_abs_gradnorm=1e-3) tol_abs_gradnorm=1e-3)
minimizer = ift.RelaxedNewton(ic_newton) minimizer = ift.RelaxedNewton(ic_newton)
# Minimize the Hamiltonian # Minimize the Hamiltonian
H = MyHamiltonian(position, OperatorSequence([likelihood,lamb])) H = MyHamiltonian(likelihood)
H = EnergyAdapter(position, H)
#ift.extra.check_value_gradient_consistency(H) #ift.extra.check_value_gradient_consistency(H)
H = H.make_invertible(ic_cg) H = H.make_invertible(ic_cg)
H, convergence = minimizer(H) H, convergence = minimizer(H)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment