From ab4209bf7243712cdf4d91d5caa8014da0ea7f1f Mon Sep 17 00:00:00 2001 From: Theo Steininger <theos@mpa-garching.mpg.de> Date: Thu, 9 Feb 2017 11:41:06 +0100 Subject: [PATCH] Added RelaxedNewton globally to NIFTy. Fixed descent_direction norming. Extended wiener_filter_hamiltonian.py <- activated plotting. --- demos/wiener_filter_hamiltonian.py | 93 ++++++++++++++++------------ nifty/minimization/__init__.py | 1 + nifty/minimization/relaxed_newton.py | 11 ++-- 3 files changed, 59 insertions(+), 46 deletions(-) diff --git a/demos/wiener_filter_hamiltonian.py b/demos/wiener_filter_hamiltonian.py index 97446759d..3b7adf40c 100644 --- a/demos/wiener_filter_hamiltonian.py +++ b/demos/wiener_filter_hamiltonian.py @@ -1,13 +1,14 @@ from nifty import * -#import plotly.offline as pl -#import plotly.graph_objs as go +import plotly.offline as pl +import plotly.graph_objs as go from mpi4py import MPI comm = MPI.COMM_WORLD rank = comm.rank +np.random.seed(42) class WienerFilterEnergy(Energy): def __init__(self, position, D, j): @@ -34,6 +35,17 @@ class WienerFilterEnergy(Energy): return_g.val = g.val.real return return_g + @property + def curvature(self): + class Dummy(object): + def __init__(self, x): + self.x = x + def inverse_times(self, *args, **kwargs): + return self.x.times(*args, **kwargs) + my_dummy = Dummy(self.D) + return my_dummy + + @memo def D_inverse_x(self): return D.inverse_times(self.position) @@ -82,14 +94,18 @@ if __name__ == "__main__": x = energy.position print (iteration, ((x-ss).norm()/ss.norm()).real) - minimizer = SteepestDescent(convergence_tolerance=0, - iteration_limit=50, - callback=distance_measure) +# minimizer = SteepestDescent(convergence_tolerance=0, +# iteration_limit=50, +# callback=distance_measure) + + minimizer = RelaxedNewton(convergence_tolerance=0, + iteration_limit=2, + callback=distance_measure) - minimizer = VL_BFGS(convergence_tolerance=0, - iteration_limit=50, - callback=distance_measure, - max_history_length=3) +# minimizer = VL_BFGS(convergence_tolerance=0, +# iteration_limit=50, +# callback=distance_measure, +# max_history_length=3) m0 = Field(s_space, val=1) @@ -97,40 +113,35 @@ if __name__ == "__main__": (energy, convergence) = minimizer(energy) + m = energy.position + d_data = d.val.get_full_data().real + if rank == 0: + pl.plot([go.Heatmap(z=d_data)], filename='data.html') + + + ss_data = ss.val.get_full_data().real + if rank == 0: + pl.plot([go.Heatmap(z=ss_data)], filename='ss.html') + + sh_data = sh.val.get_full_data().real + if rank == 0: + pl.plot([go.Heatmap(z=sh_data)], filename='sh.html') + + j_data = j.val.get_full_data().real + if rank == 0: + pl.plot([go.Heatmap(z=j_data)], filename='j.html') + + jabs_data = np.abs(j.val.get_full_data()) + jphase_data = np.angle(j.val.get_full_data()) + if rank == 0: + pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html') + pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html') + + m_data = m.val.get_full_data().real + if rank == 0: + pl.plot([go.Heatmap(z=m_data)], filename='map.html') -# -# -# -# grad = gradient(m) -# -# d_data = d.val.get_full_data().real -# if rank == 0: -# pl.plot([go.Heatmap(z=d_data)], filename='data.html') -# -# -# ss_data = ss.val.get_full_data().real -# if rank == 0: -# pl.plot([go.Heatmap(z=ss_data)], filename='ss.html') -# -# sh_data = sh.val.get_full_data().real -# if rank == 0: -# pl.plot([go.Heatmap(z=sh_data)], filename='sh.html') -# -# j_data = j.val.get_full_data().real -# if rank == 0: -# pl.plot([go.Heatmap(z=j_data)], filename='j.html') -# -# jabs_data = np.abs(j.val.get_full_data()) -# jphase_data = np.angle(j.val.get_full_data()) -# if rank == 0: -# pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html') -# pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html') -# -# m_data = m.val.get_full_data().real -# if rank == 0: -# pl.plot([go.Heatmap(z=m_data)], filename='map.html') -# # grad_data = grad.val.get_full_data().real # if rank == 0: # pl.plot([go.Heatmap(z=grad_data)], filename='grad.html') diff --git a/nifty/minimization/__init__.py b/nifty/minimization/__init__.py index 7a9464703..2d31c55c6 100644 --- a/nifty/minimization/__init__.py +++ b/nifty/minimization/__init__.py @@ -5,3 +5,4 @@ from conjugate_gradient import ConjugateGradient from quasi_newton_minimizer import QuasiNewtonMinimizer from steepest_descent import SteepestDescent from vl_bfgs import VL_BFGS +from relaxed_newton import RelaxedNewton diff --git a/nifty/minimization/relaxed_newton.py b/nifty/minimization/relaxed_newton.py index b42016055..0dc5f32d8 100644 --- a/nifty/minimization/relaxed_newton.py +++ b/nifty/minimization/relaxed_newton.py @@ -21,8 +21,9 @@ class RelaxedNewton(QuasiNewtonMinimizer): gradient = energy.gradient curvature = energy.curvature descend_direction = curvature.inverse_times(gradient) - norm = descend_direction.norm() - if norm != 1: - return descend_direction / -norm - else: - return descend_direction * -1 + return descend_direction * -1 + #norm = descend_direction.norm() +# if norm != 1: +# return descend_direction / -norm +# else: +# return descend_direction * -1 -- GitLab