diff --git a/demos/wiener_filter_hamiltonian.py b/demos/wiener_filter_hamiltonian.py index 71c4dce564dca68810f3332c5b563716b934f7ba..118bc6e693812f1490483b78eac03775aa7f52e7 100644 --- a/demos/wiener_filter_hamiltonian.py +++ b/demos/wiener_filter_hamiltonian.py @@ -1,5 +1,6 @@ from nifty import * + import plotly.offline as pl import plotly.graph_objs as go @@ -8,17 +9,47 @@ comm = MPI.COMM_WORLD rank = comm.rank +class WienerFilterEnergy(Energy): + def __init__(self, position, D, j): + # in principle not necessary, but useful in order to make the signature + # explicit + super(WienerFilterEnergy, self).__init__(position) + self.D = D + self.j = j + + def at(self, position): + return self.__class__(position, D=self.D, j=self.j) + + @property + def value(self): + D_inv_x = self.D_inverse_x() + H = 0.5 * D_inv_x.dot(self.position) - self.j.dot(self.position) + return H.real + + @property + def gradient(self): + D_inv_x = self.D_inverse_x() + g = D_inv_x - self.j + return_g = g.copy_empty(dtype=np.float) + return_g.val = g.val.real + return return_g + + def D_inverse_x(self): + return D.inverse_times(self.position) + + if __name__ == "__main__": distribution_strategy = 'fftw' + # Set up spaces and fft transformation s_space = RGSpace([512, 512], dtype=np.float) fft = FFTOperator(s_space) h_space = fft.target[0] p_space = PowerSpace(h_space, distribution_strategy=distribution_strategy) + # create the field instances and power operator pow_spec = (lambda k: (42 / (k + 1) ** 3)) - S = create_power_operator(h_space, power_spectrum=pow_spec, distribution_strategy=distribution_strategy) @@ -27,8 +58,8 @@ if __name__ == "__main__": sh = sp.power_synthesize(real_signal=True) ss = fft.inverse_times(sh) + # model the measurement process R = SmoothingOperator(s_space, sigma=0.01) - # R = DiagonalOperator(s_space, diagonal=1.) # R._diagonal.val[200:400, 200:400] = 0 @@ -38,70 +69,67 @@ if __name__ == "__main__": random_type='normal', std=ss.std()/np.sqrt(signal_to_noise), mean=0) - #n.val.data.imag[:] = 0 + # create mock data d = R(ss) + n + + # set up reconstruction objects j = R.adjoint_times(N.inverse_times(d)) D = PropagatorOperator(S=S, N=N, R=R) - def energy(x): - DIx = D.inverse_times(x) - H = 0.5 * DIx.dot(x) - j.dot(x) - return H.real - - def gradient(x): - DIx = D.inverse_times(x) - g = DIx - j - return_g = g.copy_empty(dtype=np.float) - return_g.val = g.val.real - return return_g - - def distance_measure(x, fgrad, iteration): - print (iteration, ((x-ss).norm()/ss.norm()).real) + def distance_measure(energy, iteration): + pass + #print (iteration, ((x-ss).norm()/ss.norm()).real) minimizer = SteepestDescent(convergence_tolerance=0, iteration_limit=50, callback=distance_measure) - minimizer = VL_BFGS(convergence_tolerance=0, - iteration_limit=50, - callback=distance_measure, - max_history_length=5) +# minimizer = VL_BFGS(convergence_tolerance=0, +# iteration_limit=50, +# callback=distance_measure, +# max_history_length=5) m0 = Field(s_space, val=1) - (m, convergence) = minimizer(m0, energy, gradient) - - - grad = gradient(m) - - d_data = d.val.get_full_data().real - if rank == 0: - pl.plot([go.Heatmap(z=d_data)], filename='data.html') - - - ss_data = ss.val.get_full_data().real - if rank == 0: - pl.plot([go.Heatmap(z=ss_data)], filename='ss.html') - - sh_data = sh.val.get_full_data().real - if rank == 0: - pl.plot([go.Heatmap(z=sh_data)], filename='sh.html') - - j_data = j.val.get_full_data().real - if rank == 0: - pl.plot([go.Heatmap(z=j_data)], filename='j.html') - - jabs_data = np.abs(j.val.get_full_data()) - jphase_data = np.angle(j.val.get_full_data()) - if rank == 0: - pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html') - pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html') - - m_data = m.val.get_full_data().real - if rank == 0: - pl.plot([go.Heatmap(z=m_data)], filename='map.html') - - grad_data = grad.val.get_full_data().real - if rank == 0: - pl.plot([go.Heatmap(z=grad_data)], filename='grad.html') + energy = WienerFilterEnergy(position=m0, D=D, j=j) + + (energy, convergence) = minimizer(energy) + + + +# +# +# +# grad = gradient(m) +# +# d_data = d.val.get_full_data().real +# if rank == 0: +# pl.plot([go.Heatmap(z=d_data)], filename='data.html') +# +# +# ss_data = ss.val.get_full_data().real +# if rank == 0: +# pl.plot([go.Heatmap(z=ss_data)], filename='ss.html') +# +# sh_data = sh.val.get_full_data().real +# if rank == 0: +# pl.plot([go.Heatmap(z=sh_data)], filename='sh.html') +# +# j_data = j.val.get_full_data().real +# if rank == 0: +# pl.plot([go.Heatmap(z=j_data)], filename='j.html') +# +# jabs_data = np.abs(j.val.get_full_data()) +# jphase_data = np.angle(j.val.get_full_data()) +# if rank == 0: +# pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html') +# pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html') +# +# m_data = m.val.get_full_data().real +# if rank == 0: +# pl.plot([go.Heatmap(z=m_data)], filename='map.html') +# +# grad_data = grad.val.get_full_data().real +# if rank == 0: +# pl.plot([go.Heatmap(z=grad_data)], filename='grad.html') diff --git a/nifty/__init__.py b/nifty/__init__.py index c73c9628857770e9fe7d3ba95e5aa24afd11b178..bc16738c55f251e07685a49c68b67c94fb4284c6 100644 --- a/nifty/__init__.py +++ b/nifty/__init__.py @@ -40,6 +40,8 @@ from config import dependency_injector,\ from d2o import distributed_data_object, d2o_librarian +from energies import * + from field import Field from random import Random diff --git a/nifty/energies/__init__.py b/nifty/energies/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f5356d255bde20b19f743aa97674b5b4d83c584 --- /dev/null +++ b/nifty/energies/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- + +from energy import Energy +from line_energy import LineEnergy diff --git a/nifty/energies/energy.py b/nifty/energies/energy.py new file mode 100644 index 0000000000000000000000000000000000000000..325e799507453180d9e892934fce56948cb188d5 --- /dev/null +++ b/nifty/energies/energy.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + + +class Energy(object): + def __init__(self, position): + self._cache = {} + try: + position = position.copy() + except AttributeError: + pass + self.position = position + + def at(self, position): + return self.__class__(position) + + @property + def value(self): + raise NotImplementedError + + @property + def gradient(self): + raise NotImplementedError + + @property + def curvature(self): + raise NotImplementedError + + def memo(f): + name = id(f) + + def wrapped_f(self): + try: + return self._cache[name] + except KeyError: + self._cache[name] = f(self) + return self._cache[name] + return wrapped_f + diff --git a/nifty/energies/line_energy.py b/nifty/energies/line_energy.py new file mode 100644 index 0000000000000000000000000000000000000000..6268553e98fe321904f16fa790ed5cac879a7e7e --- /dev/null +++ b/nifty/energies/line_energy.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- + +from .energy import Energy + + +class LineEnergy(Energy): + def __init__(self, position, energy, line_direction): + self.energy = energy + self.line_direction = line_direction + super(LineEnergy, self).__init__(position=position) + + def at(self, position): + if position == 0: + return self + else: + full_position = self.position + self.line_direction*position + return self.__class__(full_position, + self.energy, + self.line_direction) + + @property + def value(self): + return self.energy.value + + @property + def gradient(self): + return self.energy.gradient.dot(self.line_direction) + + @property + def curvature(self): + return self.energy.curvature diff --git a/nifty/minimization/line_searching/line_search.py b/nifty/minimization/line_searching/line_search.py index 4f28e702f55c689f49ae255877054a99826bdef5..b9d68a6754b278f30fca794784ddcbc325d00ec8 100644 --- a/nifty/minimization/line_searching/line_search.py +++ b/nifty/minimization/line_searching/line_search.py @@ -2,6 +2,8 @@ import abc from keepers import Loggable +from nifty import LineEnergy + class LineSearch(object, Loggable): """ @@ -26,23 +28,11 @@ class LineSearch(object, Loggable): derivation. """ - self.xk = None self.pk = None - - self.f_k = None + self.line_energy = None self.f_k_minus_1 = None - self.fprime_k = None - - def set_functions(self, f, fprime, f_args=()): - assert(callable(f)) - assert(callable(fprime)) - self.f = f - self.fprime = fprime - self.f_args = f_args - - def _set_coordinates(self, xk, pk, f_k=None, fprime_k=None, - f_k_minus_1=None): + def _set_line_energy(self, energy, pk, f_k_minus_1=None): """ Set the coordinates for a new line search. @@ -61,39 +51,13 @@ class LineSearch(object, Loggable): Function value fprime(xk). """ - - self.xk = xk.copy() - self.pk = pk.copy() - - if f_k is None: - self.f_k = self.f(xk) - else: - self.f_k = f_k - - if fprime_k is None: - self.fprime_k = self.fprime(xk) - else: - self.fprime_k = fprime_k - + self.line_energy = LineEnergy(position=0., + energy=energy, + line_direction=pk) if f_k_minus_1 is not None: f_k_minus_1 = f_k_minus_1.copy() self.f_k_minus_1 = f_k_minus_1 - def _phi(self, alpha): - if alpha == 0: - value = self.f_k - else: - value = self.f(self.xk + self.pk*alpha, *self.f_args) - return value - - def _phiprime(self, alpha): - if alpha == 0: - gradient = self.fprime_k - else: - gradient = self.fprime(self.xk + self.pk*alpha, *self.f_args) - - return gradient.dot(self.pk) - @abc.abstractmethod def perform_line_search(self, xk, pk, f_k=None, fprime_k=None, f_k_minus_1=None): diff --git a/nifty/minimization/line_searching/line_search_strong_wolfe.py b/nifty/minimization/line_searching/line_search_strong_wolfe.py index 602dceb209e9b91d0adc531964ca2245984ca9ea..ac08d3d18927ec77f0e23caa0a4a1f8a9badaae8 100644 --- a/nifty/minimization/line_searching/line_search_strong_wolfe.py +++ b/nifty/minimization/line_searching/line_search_strong_wolfe.py @@ -45,13 +45,8 @@ class LineSearchStrongWolfe(LineSearch): self.max_zoom_iterations = int(max_zoom_iterations) self._last_alpha_star = 1. - def perform_line_search(self, xk, pk, f_k=None, fprime_k=None, - f_k_minus_1=None): - self._set_coordinates(xk=xk, - pk=pk, - f_k=f_k, - fprime_k=fprime_k, - f_k_minus_1=f_k_minus_1) + def perform_line_search(self, energy, pk, f_k_minus_1=None): + self._set_line_energy(energy, pk, f_k_minus_1=f_k_minus_1) c1 = self.c1 c2 = self.c2 max_step_size = self.max_step_size @@ -59,8 +54,8 @@ class LineSearchStrongWolfe(LineSearch): # initialize the zero phis old_phi_0 = self.f_k_minus_1 - phi_0 = self._phi(0.) - phiprime_0 = self._phiprime(0.) + phi_0 = self.line_energy.at(0).value + phiprime_0 = self.line_energy.at(0).gradient if phiprime_0 == 0: self.logger.warn("Flat gradient in search direction.") @@ -81,7 +76,8 @@ class LineSearchStrongWolfe(LineSearch): # start the minimization loop for i in xrange(max_iterations): - phi_alpha1 = self._phi(alpha1) + energy_alpha1 = self.line_energy.at(alpha1) + phi_alpha1 = energy_alpha1.value if alpha1 == 0: self.logger.warn("Increment size became 0.") alpha_star = 0. @@ -98,7 +94,7 @@ class LineSearchStrongWolfe(LineSearch): c1, c2) break - phiprime_alpha1 = self._phiprime(alpha1) + phiprime_alpha1 = energy_alpha1.gradient if abs(phiprime_alpha1) <= -c2*phiprime_0: alpha_star = alpha1 phi_star = phi_alpha1 @@ -165,7 +161,8 @@ class LineSearchStrongWolfe(LineSearch): alpha_j = alpha_lo + 0.5*delta_alpha # Check if the current value of alpha_j is already sufficient - phi_alphaj = self._phi(alpha_j) + energy_alphaj = self.line_energy.at(alpha_j) + phi_alphaj = energy_alphaj.value # If the first Wolfe condition is not met replace alpha_hi # by alpha_j @@ -174,7 +171,7 @@ class LineSearchStrongWolfe(LineSearch): alpha_recent, phi_recent = alpha_hi, phi_hi alpha_hi, phi_hi = alpha_j, phi_alphaj else: - phiprime_alphaj = self._phiprime(alpha_j) + phiprime_alphaj = energy_alphaj.gradient # If the second Wolfe condition is met, return the result if abs(phiprime_alphaj) <= -c2*phiprime_0: alpha_star = alpha_j diff --git a/nifty/minimization/quasi_newton_minimizer.py b/nifty/minimization/quasi_newton_minimizer.py index d413c686596ed54013fc10c7d4707807e1053b50..284a8f135e65a7fa7d44c73d641f767a620b6638 100644 --- a/nifty/minimization/quasi_newton_minimizer.py +++ b/nifty/minimization/quasi_newton_minimizer.py @@ -26,7 +26,7 @@ class QuasiNewtonMinimizer(object, Loggable): self.line_searcher = line_searcher self.callback = callback - def __call__(self, x0, f, fprime, f_args=()): + def __call__(self, energy): """ Runs the steepest descent minimization. @@ -56,49 +56,45 @@ class QuasiNewtonMinimizer(object, Loggable): """ - x = x0.copy() - self.line_searcher.set_functions(f=f, fprime=fprime, f_args=f_args) - convergence = 0 f_k_minus_1 = None - f_k = f(x) step_length = 0 iteration_number = 1 while True: if self.callback is not None: try: - self.callback(x, f_k, iteration_number) + self.callback(energy, iteration_number) except StopIteration: self.logger.info("Minimization was stopped by callback " "function.") break - # compute the the gradient for the current x - gradient = fprime(x) + # compute the the gradient for the current location + gradient = energy.gradient gradient_norm = gradient.dot(gradient) - # check if x is at a flat point + # check if position is at a flat point if gradient_norm == 0: self.logger.info("Reached perfectly flat point. Stopping.") convergence = self.convergence_level+2 break - descend_direction = self._get_descend_direction(x, gradient) + current_position = energy.position + descend_direction = self._get_descend_direction(current_position, + gradient) - # compute the step length, which minimizes f_k along the - # search direction = the gradient - step_length, new_f_k = self.line_searcher.perform_line_search( - xk=x, + # compute the step length, which minimizes energy.value along the + # search direction + step_length, step_length = self.line_searcher.perform_line_search( + energy=energy, pk=descend_direction, - f_k=f_k, - fprime_k=gradient, f_k_minus_1=f_k_minus_1) - f_k_minus_1 = f_k - f_k = new_f_k + new_position = current_position + step_length * descend_direction + new_energy = energy.at(new_position) - # update x - x += descend_direction*step_length + f_k_minus_1 = energy.value + energy = new_energy # check convergence delta = abs(gradient).max() * (step_length/gradient_norm) @@ -127,7 +123,7 @@ class QuasiNewtonMinimizer(object, Loggable): iteration_number += 1 - return x, convergence + return energy, convergence @abc.abstractmethod def _get_descend_direction(self, gradient, gradient_norm):