Commit 72fa347b authored by Theo Steininger
Browse files

Merge branch 'energy' into 'master'

Energy



See merge request !34
parents 41b9419d 4d0eace3
from nifty import * from nifty import *
import plotly.offline as pl import plotly.offline as pl
import plotly.graph_objs as go import plotly.graph_objs as go
...@@ -8,17 +9,48 @@ comm = MPI.COMM_WORLD ...@@ -8,17 +9,48 @@ comm = MPI.COMM_WORLD
rank = comm.rank rank = comm.rank
class WienerFilterEnergy(Energy):
    """Quadratic energy built from an operator ``D`` and a field ``j``.

    With ``x = self.position`` the energy value is the real part of
    ``0.5 * (D^-1 x).dot(x) - j.dot(x)`` and the gradient is the real part
    of ``D^-1 x - j``.

    Parameters
    ----------
    position
        Point at which the energy is evaluated.
    D
        Operator providing ``inverse_times``.
    j
        Object providing ``dot`` that can be subtracted from ``D^-1 x``.
    """

    def __init__(self, position, D, j):
        # in principle not necessary, but useful in order to make the
        # signature explicit
        super(WienerFilterEnergy, self).__init__(position)
        self.D = D
        self.j = j

    def at(self, position):
        """Return a new energy of the same class at ``position``.

        ``D`` and ``j`` are handed over unchanged to the new instance.
        """
        return self.__class__(position, D=self.D, j=self.j)

    @property
    def value(self):
        """Real part of ``0.5 * (D^-1 x).dot(x) - j.dot(x)``."""
        D_inv_x = self.D_inverse_x()
        H = 0.5 * D_inv_x.dot(self.position) - self.j.dot(self.position)
        return H.real

    @property
    def gradient(self):
        """Real part of ``D^-1 x - j``, returned as a float-typed copy."""
        D_inv_x = self.D_inverse_x()
        g = D_inv_x - self.j
        # ``float`` is the very object the former ``np.float`` alias pointed
        # to; the alias itself was removed from recent NumPy releases.
        return_g = g.copy_empty(dtype=float)
        return_g.val = g.val.real
        return return_g

    @memo
    def D_inverse_x(self):
        """Return ``D^-1 x``; ``memo`` caches the result on this instance."""
        # Use the operator stored on the instance.  The previous code read a
        # module-level name ``D`` and therefore only worked when the calling
        # script happened to define such a global.
        return self.D.inverse_times(self.position)
if __name__ == "__main__": if __name__ == "__main__":
distribution_strategy = 'fftw' distribution_strategy = 'fftw'
# Set up spaces and fft transformation
s_space = RGSpace([512, 512], dtype=np.float) s_space = RGSpace([512, 512], dtype=np.float)
fft = FFTOperator(s_space) fft = FFTOperator(s_space)
h_space = fft.target[0] h_space = fft.target[0]
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy) p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
# create the field instances and power operator
pow_spec = (lambda k: (42 / (k + 1) ** 3)) pow_spec = (lambda k: (42 / (k + 1) ** 3))
S = create_power_operator(h_space, power_spectrum=pow_spec, S = create_power_operator(h_space, power_spectrum=pow_spec,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
...@@ -27,8 +59,8 @@ if __name__ == "__main__": ...@@ -27,8 +59,8 @@ if __name__ == "__main__":
sh = sp.power_synthesize(real_signal=True) sh = sp.power_synthesize(real_signal=True)
ss = fft.inverse_times(sh) ss = fft.inverse_times(sh)
# model the measurement process
R = SmoothingOperator(s_space, sigma=0.01) R = SmoothingOperator(s_space, sigma=0.01)
# R = DiagonalOperator(s_space, diagonal=1.) # R = DiagonalOperator(s_space, diagonal=1.)
# R._diagonal.val[200:400, 200:400] = 0 # R._diagonal.val[200:400, 200:400] = 0
...@@ -38,70 +70,67 @@ if __name__ == "__main__": ...@@ -38,70 +70,67 @@ if __name__ == "__main__":
random_type='normal', random_type='normal',
std=ss.std()/np.sqrt(signal_to_noise), std=ss.std()/np.sqrt(signal_to_noise),
mean=0) mean=0)
#n.val.data.imag[:] = 0
# create mock data
d = R(ss) + n d = R(ss) + n
# set up reconstruction objects
j = R.adjoint_times(N.inverse_times(d)) j = R.adjoint_times(N.inverse_times(d))
D = PropagatorOperator(S=S, N=N, R=R) D = PropagatorOperator(S=S, N=N, R=R)
def energy(x): def distance_measure(energy, iteration):
DIx = D.inverse_times(x) x = energy.position
H = 0.5 * DIx.dot(x) - j.dot(x)
return H.real
def gradient(x):
DIx = D.inverse_times(x)
g = DIx - j
return_g = g.copy_empty(dtype=np.float)
return_g.val = g.val.real
return return_g
def distance_measure(x, fgrad, iteration):
print (iteration, ((x-ss).norm()/ss.norm()).real) print (iteration, ((x-ss).norm()/ss.norm()).real)
minimizer = SteepestDescent(convergence_tolerance=0, minimizer = SteepestDescent(convergence_tolerance=0,
iteration_limit=50, iteration_limit=50,
callback=distance_measure) callback=distance_measure)
minimizer = VL_BFGS(convergence_tolerance=0, minimizer = VL_BFGS(convergence_tolerance=0,
iteration_limit=50, iteration_limit=50,
callback=distance_measure, callback=distance_measure,
max_history_length=5) max_history_length=3)
m0 = Field(s_space, val=1) m0 = Field(s_space, val=1)
(m, convergence) = minimizer(m0, energy, gradient) energy = WienerFilterEnergy(position=m0, D=D, j=j)
(energy, convergence) = minimizer(energy)
grad = gradient(m)
d_data = d.val.get_full_data().real
if rank == 0: #
pl.plot([go.Heatmap(z=d_data)], filename='data.html') #
#
# grad = gradient(m)
ss_data = ss.val.get_full_data().real #
if rank == 0: # d_data = d.val.get_full_data().real
pl.plot([go.Heatmap(z=ss_data)], filename='ss.html') # if rank == 0:
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
sh_data = sh.val.get_full_data().real #
if rank == 0: #
pl.plot([go.Heatmap(z=sh_data)], filename='sh.html') # ss_data = ss.val.get_full_data().real
# if rank == 0:
j_data = j.val.get_full_data().real # pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')
if rank == 0: #
pl.plot([go.Heatmap(z=j_data)], filename='j.html') # sh_data = sh.val.get_full_data().real
# if rank == 0:
jabs_data = np.abs(j.val.get_full_data()) # pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')
jphase_data = np.angle(j.val.get_full_data()) #
if rank == 0: # j_data = j.val.get_full_data().real
pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html') # if rank == 0:
pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html') # pl.plot([go.Heatmap(z=j_data)], filename='j.html')
#
m_data = m.val.get_full_data().real # jabs_data = np.abs(j.val.get_full_data())
if rank == 0: # jphase_data = np.angle(j.val.get_full_data())
pl.plot([go.Heatmap(z=m_data)], filename='map.html') # if rank == 0:
# pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html')
grad_data = grad.val.get_full_data().real # pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html')
if rank == 0: #
pl.plot([go.Heatmap(z=grad_data)], filename='grad.html') # m_data = m.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
#
# grad_data = grad.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=grad_data)], filename='grad.html')
...@@ -40,6 +40,8 @@ from config import dependency_injector,\ ...@@ -40,6 +40,8 @@ from config import dependency_injector,\
from d2o import distributed_data_object, d2o_librarian from d2o import distributed_data_object, d2o_librarian
from energies import *
from field import Field from field import Field
from random import Random from random import Random
......
# -*- coding: utf-8 -*-
from energy import Energy
from line_energy import LineEnergy
from memoization import memo
# -*- coding: utf-8 -*-
class Energy(object):
    """Base class for an energy evaluated at a fixed position.

    Subclasses are expected to override ``value``, ``gradient`` and
    ``curvature``; the implementations here raise ``NotImplementedError``.

    Parameters
    ----------
    position
        Point at which the energy is evaluated.  If the object offers a
        ``copy`` method, a copy is stored; otherwise the object itself is
        stored.
    """

    def __init__(self, position):
        # Per-instance result store read and written by the ``memo``
        # decorator.
        self._cache = {}
        # Only the attribute lookup is guarded.  An AttributeError raised
        # *inside* a faulty ``copy`` implementation must surface instead of
        # silently leaving the energy with a shared, uncopied position.
        try:
            copy = position.copy
        except AttributeError:
            # Objects without ``copy`` (e.g. plain numbers) are stored as is.
            pass
        else:
            position = copy()
        self.position = position

    def at(self, position):
        """Return a new instance of the same class located at ``position``."""
        return self.__class__(position)

    @property
    def value(self):
        """Energy value at ``self.position``; must be overridden."""
        raise NotImplementedError

    @property
    def gradient(self):
        """Gradient at ``self.position``; must be overridden."""
        raise NotImplementedError

    @property
    def curvature(self):
        """Curvature at ``self.position``; must be overridden."""
        raise NotImplementedError
# -*- coding: utf-8 -*-
from .energy import Energy
class LineEnergy(Energy):
    """Restriction of an energy to the line ``zero_point + t*line_direction``.

    ``position`` is the line parameter ``t``.  The wrapped energy is
    re-evaluated (via its ``at`` method) at the corresponding point on the
    line and stored as ``self.energy``.

    Parameters
    ----------
    position
        Line parameter ``t``.
    energy
        Energy object providing ``position`` and ``at``.
    line_direction
        Direction along which the line runs.
    zero_point
        Point on the line belonging to ``t = 0``.  Defaults to
        ``energy.position``.
    """

    def __init__(self, position, energy, line_direction, zero_point=None):
        super(LineEnergy, self).__init__(position=position)
        self.line_direction = line_direction
        # Fall back to the current location of the wrapped energy as origin.
        self._zero_point = (energy.position if zero_point is None
                            else zero_point)
        displaced = self._zero_point + self.position*line_direction
        self.energy = energy.at(position=displaced)

    def at(self, position):
        """Return a line energy for parameter ``position`` on the same line.

        The stored zero point is passed on explicitly, so the origin of the
        line does not move with the wrapped energy.
        """
        cls = self.__class__
        return cls(position,
                   self.energy,
                   self.line_direction,
                   zero_point=self._zero_point)

    @property
    def value(self):
        """Value of the wrapped energy at the current point on the line."""
        wrapped = self.energy
        return wrapped.value

    @property
    def gradient(self):
        """Wrapped energy's gradient contracted with the line direction."""
        full_gradient = self.energy.gradient
        return full_gradient.dot(self.line_direction)

    @property
    def curvature(self):
        """Curvature of the wrapped energy, passed through unchanged."""
        wrapped = self.energy
        return wrapped.curvature
# -*- coding: utf-8 -*-
def memo(f):
    """Cache the result of a method that takes no arguments besides ``self``.

    The result is stored in the instance's ``_cache`` dictionary, so the
    decorated method is evaluated at most once per instance for as long as
    that dictionary keeps the entry.  The instance must provide ``_cache``.

    Parameters
    ----------
    f : callable
        Method of the form ``f(self)``.

    Returns
    -------
    callable
        Wrapper with the same name and docstring as ``f``.
    """
    # Local import keeps this block self-contained.
    import functools

    # ``f`` stays alive through the closure, so its id is a stable and
    # unique key for the lifetime of the wrapper.
    key = id(f)

    @functools.wraps(f)
    def wrapped_f(self):
        try:
            return self._cache[key]
        except KeyError:
            result = f(self)
            self._cache[key] = result
            return result

    return wrapped_f
...@@ -2,6 +2,8 @@ import abc ...@@ -2,6 +2,8 @@ import abc
from keepers import Loggable from keepers import Loggable
from nifty import LineEnergy
class LineSearch(object, Loggable): class LineSearch(object, Loggable):
""" """
...@@ -26,23 +28,11 @@ class LineSearch(object, Loggable): ...@@ -26,23 +28,11 @@ class LineSearch(object, Loggable):
derivation. derivation.
""" """
self.xk = None
self.pk = None self.pk = None
self.line_energy = None
self.f_k = None
self.f_k_minus_1 = None self.f_k_minus_1 = None
self.fprime_k = None
def set_functions(self, f, fprime, f_args=()):
assert(callable(f))
assert(callable(fprime))
self.f = f def _set_line_energy(self, energy, pk, f_k_minus_1=None):
self.fprime = fprime
self.f_args = f_args
def _set_coordinates(self, xk, pk, f_k=None, fprime_k=None,
f_k_minus_1=None):
""" """
Set the coordinates for a new line search. Set the coordinates for a new line search.
...@@ -61,40 +51,13 @@ class LineSearch(object, Loggable): ...@@ -61,40 +51,13 @@ class LineSearch(object, Loggable):
Function value fprime(xk). Function value fprime(xk).
""" """
self.line_energy = LineEnergy(position=0.,
self.xk = xk.copy() energy=energy,
self.pk = pk.copy() line_direction=pk)
if f_k is None:
self.f_k = self.f(xk)
else:
self.f_k = f_k
if fprime_k is None:
self.fprime_k = self.fprime(xk)
else:
self.fprime_k = fprime_k
if f_k_minus_1 is not None: if f_k_minus_1 is not None:
f_k_minus_1 = f_k_minus_1.copy() f_k_minus_1 = f_k_minus_1.copy()
self.f_k_minus_1 = f_k_minus_1 self.f_k_minus_1 = f_k_minus_1
def _phi(self, alpha):
if alpha == 0:
value = self.f_k
else:
value = self.f(self.xk + self.pk*alpha, *self.f_args)
return value
def _phiprime(self, alpha):
if alpha == 0:
gradient = self.fprime_k
else:
gradient = self.fprime(self.xk + self.pk*alpha, *self.f_args)
return gradient.dot(self.pk)
@abc.abstractmethod @abc.abstractmethod
def perform_line_search(self, xk, pk, f_k=None, fprime_k=None, def perform_line_search(self, energy, pk, f_k_minus_1=None):
f_k_minus_1=None):
raise NotImplementedError raise NotImplementedError
...@@ -45,13 +45,8 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -45,13 +45,8 @@ class LineSearchStrongWolfe(LineSearch):
self.max_zoom_iterations = int(max_zoom_iterations) self.max_zoom_iterations = int(max_zoom_iterations)
self._last_alpha_star = 1. self._last_alpha_star = 1.
def perform_line_search(self, xk, pk, f_k=None, fprime_k=None, def perform_line_search(self, energy, pk, f_k_minus_1=None):
f_k_minus_1=None): self._set_line_energy(energy, pk, f_k_minus_1=f_k_minus_1)
self._set_coordinates(xk=xk,
pk=pk,
f_k=f_k,
fprime_k=fprime_k,
f_k_minus_1=f_k_minus_1)
c1 = self.c1 c1 = self.c1
c2 = self.c2 c2 = self.c2
max_step_size = self.max_step_size max_step_size = self.max_step_size
...@@ -59,8 +54,9 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -59,8 +54,9 @@ class LineSearchStrongWolfe(LineSearch):
# initialize the zero phis # initialize the zero phis
old_phi_0 = self.f_k_minus_1 old_phi_0 = self.f_k_minus_1
phi_0 = self._phi(0.) energy_0 = self.line_energy.at(0)
phiprime_0 = self._phiprime(0.) phi_0 = energy_0.value
phiprime_0 = energy_0.gradient
if phiprime_0 == 0: if phiprime_0 == 0:
self.logger.warn("Flat gradient in search direction.") self.logger.warn("Flat gradient in search direction.")
...@@ -81,16 +77,19 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -81,16 +77,19 @@ class LineSearchStrongWolfe(LineSearch):
# start the minimization loop # start the minimization loop
for i in xrange(max_iterations): for i in xrange(max_iterations):
phi_alpha1 = self._phi(alpha1) energy_alpha1 = self.line_energy.at(alpha1)
phi_alpha1 = energy_alpha1.value
if alpha1 == 0: if alpha1 == 0:
self.logger.warn("Increment size became 0.") self.logger.warn("Increment size became 0.")
alpha_star = 0. alpha_star = 0.
phi_star = phi_0 phi_star = phi_0
energy_star = energy_0
break break
if (phi_alpha1 > phi_0 + c1*alpha1*phiprime_0) or \ if (phi_alpha1 > phi_0 + c1*alpha1*phiprime_0) or \
((phi_alpha1 >= phi_alpha0) and (i > 1)): ((phi_alpha1 >= phi_alpha0) and (i > 1)):
(alpha_star, phi_star) = self._zoom(alpha0, alpha1, (alpha_star, phi_star, energy_star) = self._zoom(
alpha0, alpha1,
phi_0, phiprime_0, phi_0, phiprime_0,
phi_alpha0, phi_alpha0,
phiprime_alpha0, phiprime_alpha0,
...@@ -98,14 +97,16 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -98,14 +97,16 @@ class LineSearchStrongWolfe(LineSearch):
c1, c2) c1, c2)
break break
phiprime_alpha1 = self._phiprime(alpha1) phiprime_alpha1 = energy_alpha1.gradient
if abs(phiprime_alpha1) <= -c2*phiprime_0: if abs(phiprime_alpha1) <= -c2*phiprime_0:
alpha_star = alpha1 alpha_star = alpha1
phi_star = phi_alpha1 phi_star = phi_alpha1
energy_star = energy_alpha1
break break
if phiprime_alpha1 >= 0: if phiprime_alpha1 >= 0:
(alpha_star, phi_star) = self._zoom(alpha1, alpha0, (alpha_star, phi_star, energy_star) = self._zoom(
alpha1, alpha0,
phi_0, phiprime_0, phi_0, phiprime_0,
phi_alpha1, phi_alpha1,
phiprime_alpha1, phiprime_alpha1,
...@@ -123,10 +124,15 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -123,10 +124,15 @@ class LineSearchStrongWolfe(LineSearch):
# max_iterations was reached # max_iterations was reached
alpha_star = alpha1 alpha_star = alpha1
phi_star = phi_alpha1 phi_star = phi_alpha1
energy_star = energy_alpha1
self.logger.error("The line search algorithm did not converge.") self.logger.error("The line search algorithm did not converge.")
self._last_alpha_star = alpha_star self._last_alpha_star = alpha_star
return alpha_star, phi_star
# extract the full energy from the line_energy
energy_star = energy_star.energy
return alpha_star, phi_star, energy_star
def _zoom(self, alpha_lo, alpha_hi, phi_0, phiprime_0, def _zoom(self, alpha_lo, alpha_hi, phi_0, phiprime_0,
phi_lo, phiprime_lo, phi_hi, c1, c2): phi_lo, phiprime_lo, phi_hi, c1, c2):
...@@ -165,7 +171,8 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -165,7 +171,8 @@ class LineSearchStrongWolfe(LineSearch):
alpha_j = alpha_lo + 0.5*delta_alpha alpha_j = alpha_lo + 0.5*delta_alpha
# Check if the current value of alpha_j is already sufficient # Check if the current value of alpha_j is already sufficient
phi_alphaj = self._phi(alpha_j) energy_alphaj = self.line_energy.at(alpha_j)
phi_alphaj = energy_alphaj.value
# If the first Wolfe condition is not met replace alpha_hi # If the first Wolfe condition is not met replace alpha_hi
# by alpha_j # by alpha_j
...@@ -174,11 +181,12 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -174,11 +181,12 @@ class LineSearchStrongWolfe(LineSearch):
alpha_recent, phi_recent = alpha_hi, phi_hi alpha_recent, phi_recent = alpha_hi, phi_hi
alpha_hi, phi_hi = alpha_j, phi_alphaj alpha_hi, phi_hi = alpha_j, phi_alphaj
else: else:
phiprime_alphaj = self._phiprime(alpha_j) phiprime_alphaj = energy_alphaj.gradient
# If the second Wolfe condition is met, return the result # If the second Wolfe condition is met, return the result
if abs(phiprime_alphaj) <= -c2*phiprime_0: if abs(phiprime_alphaj) <= -c2*phiprime_0:
alpha_star = alpha_j alpha_star = alpha_j
phi_star = phi_alphaj phi_star = phi_alphaj
energy_star = energy_alphaj
break break
# If not, check the sign of the slope # If not, check the sign of the slope
if phiprime_alphaj*delta_alpha >= 0: if phiprime_alphaj*delta_alpha >= 0:
...@@ -191,11 +199,12 @@ class LineSearchStrongWolfe(LineSearch): ...@@ -191,11 +199,12 @@ class LineSearchStrongWolfe(LineSearch):
phiprime_alphaj) phiprime_alphaj)
else: else:
alpha_star, phi_star = alpha_j, phi_alphaj alpha_star, phi_star, energy_star = \
alpha_j, phi_alphaj, energy_alphaj
self.logger.error("The line search algorithm (zoom) did not " self.logger.error("The line search algorithm (zoom) did not "
"converge.") "converge.")
return alpha_star, phi_star return alpha_star, phi_star, energy_star
def _cubicmin(self, a, fa, fpa, b, fb, c, fc): def _cubicmin(self, a, fa, fpa, b, fb, c, fc):
""" """
......
...@@ -26,7 +26,7 @@ class QuasiNewtonMinimizer(object, Loggable): ...@@ -26,7 +26,7 @@ class QuasiNewtonMinimizer(object, Loggable):
self.line_searcher = line_searcher self.line_searcher = line_searcher
self.callback = callback self.callback = callback
def __call__(self, x0, f, fprime, f_args=()): def __call__(self, energy):
""" """
Runs the steepest descent minimization. Runs the steepest descent minimization.
...@@ -56,49 +56,43 @@ class QuasiNewtonMinimizer(object, Loggable): ...@@ -56,49 +56,43 @@ class QuasiNewtonMinimizer(object, Loggable):