Commit 17b1c6e9 authored by theos

Implemented Energy object. Adapted minimizers to it.

parent 41b9419d
from nifty import *
import plotly.offline as pl
import plotly.graph_objs as go
......@@ -8,17 +9,47 @@ comm = MPI.COMM_WORLD
rank = comm.rank
class WienerFilterEnergy(Energy):
    """Energy of the Wiener filter problem at a given position.

    With ``x = self.position`` the value is the real part of
    ``0.5 * (D^-1 x).dot(x) - j.dot(x)`` and the gradient is the real part
    of ``D^-1 x - j``.

    Parameters
    ----------
    position :
        Point at which the energy is evaluated (a Field in the demo
        script below).
    D :
        Object providing ``inverse_times(position)``.
    j :
        Object that can be subtracted from ``D.inverse_times(position)``
        and provides ``dot(position)``.
    """

    def __init__(self, position, D, j):
        # in principle not necessary, but useful in order to make the
        # signature explicit
        super(WienerFilterEnergy, self).__init__(position)
        self.D = D
        self.j = j

    def at(self, position):
        """Return a new energy of the same class at `position`.

        The operator `D` and the source `j` are passed on unchanged.
        """
        return self.__class__(position, D=self.D, j=self.j)

    @property
    def value(self):
        """Real part of ``0.5 * (D^-1 x).dot(x) - j.dot(x)``."""
        D_inv_x = self.D_inverse_x()
        H = 0.5 * D_inv_x.dot(self.position) - self.j.dot(self.position)
        return H.real

    @property
    def gradient(self):
        """Real part of ``D^-1 x - j``, stored in a float-typed copy."""
        D_inv_x = self.D_inverse_x()
        g = D_inv_x - self.j
        # `np.float` was only an alias of the builtin `float` and has been
        # removed from NumPy; the builtin is the identical type.
        return_g = g.copy_empty(dtype=float)
        return_g.val = g.val.real
        return return_g

    def D_inverse_x(self):
        """Apply the inverse of this instance's `D` to the position."""
        # Use the operator stored on the instance, not a module-level `D`
        # that may not exist (or may be a different operator).
        return self.D.inverse_times(self.position)
if __name__ == "__main__":
distribution_strategy = 'fftw'
# Set up spaces and fft transformation
s_space = RGSpace([512, 512], dtype=np.float)
fft = FFTOperator(s_space)
h_space = fft.target[0]
p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy)
# create the field instances and power operator
pow_spec = (lambda k: (42 / (k + 1) ** 3))
S = create_power_operator(h_space, power_spectrum=pow_spec,
distribution_strategy=distribution_strategy)
......@@ -27,8 +58,8 @@ if __name__ == "__main__":
sh = sp.power_synthesize(real_signal=True)
ss = fft.inverse_times(sh)
# model the measurement process
R = SmoothingOperator(s_space, sigma=0.01)
# R = DiagonalOperator(s_space, diagonal=1.)
# R._diagonal.val[200:400, 200:400] = 0
......@@ -38,70 +69,67 @@ if __name__ == "__main__":
random_type='normal',
std=ss.std()/np.sqrt(signal_to_noise),
mean=0)
#n.val.data.imag[:] = 0
# create mock data
d = R(ss) + n
# set up reconstruction objects
j = R.adjoint_times(N.inverse_times(d))
D = PropagatorOperator(S=S, N=N, R=R)
def energy(x):
    """Return the real part of ``0.5 * (D^-1 x).dot(x) - j.dot(x)``.

    `D` and `j` are taken from the enclosing script scope.
    """
    propagated = D.inverse_times(x)
    quadratic_term = 0.5 * propagated.dot(x)
    source_term = j.dot(x)
    return (quadratic_term - source_term).real
def gradient(x):
    """Return the real part of ``D^-1 x - j`` as a float-typed copy.

    `D` and `j` are taken from the enclosing script scope.
    """
    DIx = D.inverse_times(x)
    g = DIx - j
    # `np.float` was only an alias of the builtin `float` and has been
    # removed from NumPy; the builtin is the identical type.
    return_g = g.copy_empty(dtype=float)
    return_g.val = g.val.real
    return return_g
def distance_measure(x, fgrad, iteration):
    """Print the iteration number and the relative distance of `x` to `ss`.

    The distance is ``(x - ss).norm() / ss.norm()``; only its real part is
    printed. `ss` is taken from the enclosing script scope and `fgrad` is
    not used.
    """
    relative_distance = (x - ss).norm() / ss.norm()
    print (iteration, relative_distance.real)
def distance_measure(energy, iteration):
    """Minimizer callback that currently does nothing.

    Both arguments are accepted and ignored.
    """
    # Progress reporting is switched off; the previous report was:
    # print (iteration, ((x-ss).norm()/ss.norm()).real)
    return None
minimizer = SteepestDescent(convergence_tolerance=0,
iteration_limit=50,
callback=distance_measure)
minimizer = VL_BFGS(convergence_tolerance=0,
iteration_limit=50,
callback=distance_measure,
max_history_length=5)
# minimizer = VL_BFGS(convergence_tolerance=0,
# iteration_limit=50,
# callback=distance_measure,
# max_history_length=5)
m0 = Field(s_space, val=1)
(m, convergence) = minimizer(m0, energy, gradient)
grad = gradient(m)
d_data = d.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=d_data)], filename='data.html')
ss_data = ss.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')
sh_data = sh.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')
j_data = j.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=j_data)], filename='j.html')
jabs_data = np.abs(j.val.get_full_data())
jphase_data = np.angle(j.val.get_full_data())
if rank == 0:
pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html')
pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html')
m_data = m.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=m_data)], filename='map.html')
grad_data = grad.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=grad_data)], filename='grad.html')
energy = WienerFilterEnergy(position=m0, D=D, j=j)
(energy, convergence) = minimizer(energy)
#
#
#
# grad = gradient(m)
#
# d_data = d.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
#
#
# ss_data = ss.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')
#
# sh_data = sh.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')
#
# j_data = j.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=j_data)], filename='j.html')
#
# jabs_data = np.abs(j.val.get_full_data())
# jphase_data = np.angle(j.val.get_full_data())
# if rank == 0:
# pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html')
# pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html')
#
# m_data = m.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
#
# grad_data = grad.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=grad_data)], filename='grad.html')
......@@ -40,6 +40,8 @@ from config import dependency_injector,\
from d2o import distributed_data_object, d2o_librarian
from energies import *
from field import Field
from random import Random
......
# -*- coding: utf-8 -*-
from energy import Energy
from line_energy import LineEnergy
# -*- coding: utf-8 -*-
class Energy(object):
    """Base class for an energy evaluated at a fixed position.

    The constructor stores a copy of `position` when the object offers a
    ``copy()`` method and the object itself otherwise. Subclasses are
    expected to override `value`, `gradient` and `curvature`; the base
    implementations raise ``NotImplementedError``.
    """

    def __init__(self, position):
        # Per-instance storage for memoized results.
        self._cache = {}
        try:
            stored_position = position.copy()
        except AttributeError:
            # No usable ``copy`` -- keep the object as it was handed in.
            stored_position = position
        self.position = stored_position

    def at(self, position):
        """Return a new instance of the same class located at `position`."""
        cls = self.__class__
        return cls(position)

    @property
    def value(self):
        """Energy value at ``self.position``; must be overridden."""
        raise NotImplementedError

    @property
    def gradient(self):
        """Gradient at ``self.position``; must be overridden."""
        raise NotImplementedError

    @property
    def curvature(self):
        """Curvature at ``self.position``; must be overridden."""
        raise NotImplementedError
def memo(f):
    """Memoize a one-argument method in the instance's ``_cache`` dict.

    The first call on an instance runs `f(self)` and stores the result in
    ``self._cache`` under a key unique to `f`; later calls on the same
    instance return the stored result without calling `f` again.

    Parameters
    ----------
    f : callable
        Function taking only ``self``. The instance must provide a
        ``_cache`` dictionary.

    Returns
    -------
    callable
        Wrapper with the same name and docstring as `f`.
    """
    # Local import so the decorator does not depend on the file's
    # top-level imports.
    from functools import wraps

    # `f` is kept alive by the closure below, so its id stays unique.
    name = id(f)

    @wraps(f)
    def wrapped_f(self):
        try:
            return self._cache[name]
        except KeyError:
            result = self._cache[name] = f(self)
            return result
    return wrapped_f
# -*- coding: utf-8 -*-
from .energy import Energy
class LineEnergy(Energy):
    """Restriction of an energy to a line through a fixed zero point.

    ``self.position`` is the step length ``alpha`` along `line_direction`;
    the wrapped energy is evaluated at
    ``zero_point + alpha * line_direction``.

    Parameters
    ----------
    position :
        Step length along the line (a scalar; it is compared against 0).
    energy :
        Energy object providing ``position``, ``at(position)``, ``value``,
        ``gradient`` and ``curvature``.
    line_direction :
        Direction of the line; must support multiplication by the step
        length and addition to the zero point.
    zero_point : optional
        Point on the line that corresponds to step length 0. When omitted,
        ``energy.position`` is used.
    """

    def __init__(self, position, energy, line_direction, zero_point=None):
        self.line_direction = line_direction

        # Without an explicit zero point, `energy` itself sits at step 0.
        energy_sits_at_zero = zero_point is None
        if energy_sits_at_zero:
            zero_point = energy.position
        self._zero_point = zero_point

        if energy_sits_at_zero and position == 0:
            # `energy` is already located where it is needed; avoid
            # constructing a second energy at the same point.
            self.energy = energy
        else:
            self.energy = energy.at(zero_point + line_direction*position)

        super(LineEnergy, self).__init__(position=position)

    def at(self, position):
        """Return the line energy at step length `position`.

        Returns `self` when `position` equals the current step length.
        Otherwise a new instance is built whose wrapped energy is
        re-evaluated at ``zero_point + position * line_direction``, so
        `value`, `gradient` and `curvature` refer to that displaced point.
        """
        if position == self.position:
            return self
        return self.__class__(position,
                              self.energy,
                              self.line_direction,
                              zero_point=self._zero_point)

    @property
    def value(self):
        """Value of the wrapped energy at the current point on the line."""
        return self.energy.value

    @property
    def gradient(self):
        """Wrapped energy's gradient dotted with the line direction."""
        return self.energy.gradient.dot(self.line_direction)

    @property
    def curvature(self):
        """Curvature of the wrapped energy, passed through unchanged."""
        return self.energy.curvature
......@@ -2,6 +2,8 @@ import abc
from keepers import Loggable
from nifty import LineEnergy
class LineSearch(object, Loggable):
"""
......@@ -26,23 +28,11 @@ class LineSearch(object, Loggable):
derivation.
"""
self.xk = None
self.pk = None
self.f_k = None
self.line_energy = None
self.f_k_minus_1 = None
self.fprime_k = None
def set_functions(self, f, fprime, f_args=()):
assert(callable(f))
assert(callable(fprime))
self.f = f
self.fprime = fprime
self.f_args = f_args
def _set_coordinates(self, xk, pk, f_k=None, fprime_k=None,
f_k_minus_1=None):
def _set_line_energy(self, energy, pk, f_k_minus_1=None):
"""
Set the coordinates for a new line search.
......@@ -61,39 +51,13 @@ class LineSearch(object, Loggable):
Function value fprime(xk).
"""
self.xk = xk.copy()
self.pk = pk.copy()
if f_k is None:
self.f_k = self.f(xk)
else:
self.f_k = f_k
if fprime_k is None:
self.fprime_k = self.fprime(xk)
else:
self.fprime_k = fprime_k
self.line_energy = LineEnergy(position=0.,
energy=energy,
line_direction=pk)
if f_k_minus_1 is not None:
f_k_minus_1 = f_k_minus_1.copy()
self.f_k_minus_1 = f_k_minus_1
def _phi(self, alpha):
if alpha == 0:
value = self.f_k
else:
value = self.f(self.xk + self.pk*alpha, *self.f_args)
return value
def _phiprime(self, alpha):
if alpha == 0:
gradient = self.fprime_k
else:
gradient = self.fprime(self.xk + self.pk*alpha, *self.f_args)
return gradient.dot(self.pk)
@abc.abstractmethod
def perform_line_search(self, xk, pk, f_k=None, fprime_k=None,
f_k_minus_1=None):
......
......@@ -45,13 +45,8 @@ class LineSearchStrongWolfe(LineSearch):
self.max_zoom_iterations = int(max_zoom_iterations)
self._last_alpha_star = 1.
def perform_line_search(self, xk, pk, f_k=None, fprime_k=None,
f_k_minus_1=None):
self._set_coordinates(xk=xk,
pk=pk,
f_k=f_k,
fprime_k=fprime_k,
f_k_minus_1=f_k_minus_1)
def perform_line_search(self, energy, pk, f_k_minus_1=None):
self._set_line_energy(energy, pk, f_k_minus_1=f_k_minus_1)
c1 = self.c1
c2 = self.c2
max_step_size = self.max_step_size
......@@ -59,8 +54,8 @@ class LineSearchStrongWolfe(LineSearch):
# initialize the zero phis
old_phi_0 = self.f_k_minus_1
phi_0 = self._phi(0.)
phiprime_0 = self._phiprime(0.)
phi_0 = self.line_energy.at(0).value
phiprime_0 = self.line_energy.at(0).gradient
if phiprime_0 == 0:
self.logger.warn("Flat gradient in search direction.")
......@@ -81,7 +76,8 @@ class LineSearchStrongWolfe(LineSearch):
# start the minimization loop
for i in xrange(max_iterations):
phi_alpha1 = self._phi(alpha1)
energy_alpha1 = self.line_energy.at(alpha1)
phi_alpha1 = energy_alpha1.value
if alpha1 == 0:
self.logger.warn("Increment size became 0.")
alpha_star = 0.
......@@ -98,7 +94,7 @@ class LineSearchStrongWolfe(LineSearch):
c1, c2)
break
phiprime_alpha1 = self._phiprime(alpha1)
phiprime_alpha1 = energy_alpha1.gradient
if abs(phiprime_alpha1) <= -c2*phiprime_0:
alpha_star = alpha1
phi_star = phi_alpha1
......@@ -165,7 +161,8 @@ class LineSearchStrongWolfe(LineSearch):
alpha_j = alpha_lo + 0.5*delta_alpha
# Check if the current value of alpha_j is already sufficient
phi_alphaj = self._phi(alpha_j)
energy_alphaj = self.line_energy.at(alpha_j)
phi_alphaj = energy_alphaj.value
# If the first Wolfe condition is not met replace alpha_hi
# by alpha_j
......@@ -174,7 +171,7 @@ class LineSearchStrongWolfe(LineSearch):
alpha_recent, phi_recent = alpha_hi, phi_hi
alpha_hi, phi_hi = alpha_j, phi_alphaj
else:
phiprime_alphaj = self._phiprime(alpha_j)
phiprime_alphaj = energy_alphaj.gradient
# If the second Wolfe condition is met, return the result
if abs(phiprime_alphaj) <= -c2*phiprime_0:
alpha_star = alpha_j
......
......@@ -26,7 +26,7 @@ class QuasiNewtonMinimizer(object, Loggable):
self.line_searcher = line_searcher
self.callback = callback
def __call__(self, x0, f, fprime, f_args=()):
def __call__(self, energy):
"""
Runs the steepest descent minimization.
......@@ -56,49 +56,45 @@ class QuasiNewtonMinimizer(object, Loggable):
"""
x = x0.copy()
self.line_searcher.set_functions(f=f, fprime=fprime, f_args=f_args)
convergence = 0
f_k_minus_1 = None
f_k = f(x)
step_length = 0
iteration_number = 1
while True:
if self.callback is not None:
try:
self.callback(x, f_k, iteration_number)
self.callback(energy, iteration_number)
except StopIteration:
self.logger.info("Minimization was stopped by callback "
"function.")
break
# compute the gradient for the current x
gradient = fprime(x)
# compute the gradient for the current location
gradient = energy.gradient
gradient_norm = gradient.dot(gradient)
# check if x is at a flat point
# check if position is at a flat point
if gradient_norm == 0:
self.logger.info("Reached perfectly flat point. Stopping.")
convergence = self.convergence_level+2
break
descend_direction = self._get_descend_direction(x, gradient)
current_position = energy.position
descend_direction = self._get_descend_direction(current_position,
gradient)
# compute the step length, which minimizes f_k along the
# search direction = the gradient
step_length, new_f_k = self.line_searcher.perform_line_search(
xk=x,
# compute the step length, which minimizes energy.value along the
# search direction
step_length, step_length = self.line_searcher.perform_line_search(
energy=energy,
pk=descend_direction,
f_k=f_k,
fprime_k=gradient,
f_k_minus_1=f_k_minus_1)
f_k_minus_1 = f_k
f_k = new_f_k
new_position = current_position + step_length * descend_direction
new_energy = energy.at(new_position)
# update x
x += descend_direction*step_length
f_k_minus_1 = energy.value
energy = new_energy
# check convergence
delta = abs(gradient).max() * (step_length/gradient_norm)
......@@ -127,7 +123,7 @@ class QuasiNewtonMinimizer(object, Loggable):
iteration_number += 1
return x, convergence
return energy, convergence
@abc.abstractmethod
def _get_descend_direction(self, gradient, gradient_norm):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment