From ab4209bf7243712cdf4d91d5caa8014da0ea7f1f Mon Sep 17 00:00:00 2001
From: Theo Steininger <theos@mpa-garching.mpg.de>
Date: Thu, 9 Feb 2017 11:41:06 +0100
Subject: [PATCH] Added RelaxedNewton globally to NIFTy. Fixed
 descent_direction norming. Extended wiener_filter_hamiltonian.py <- activated
 plotting.

---
 demos/wiener_filter_hamiltonian.py   | 93 ++++++++++++++++------------
 nifty/minimization/__init__.py       |  1 +
 nifty/minimization/relaxed_newton.py | 11 ++--
 3 files changed, 59 insertions(+), 46 deletions(-)

diff --git a/demos/wiener_filter_hamiltonian.py b/demos/wiener_filter_hamiltonian.py
index 97446759d..3b7adf40c 100644
--- a/demos/wiener_filter_hamiltonian.py
+++ b/demos/wiener_filter_hamiltonian.py
@@ -1,13 +1,14 @@
 
 from nifty import *
 
-#import plotly.offline as pl
-#import plotly.graph_objs as go
+import plotly.offline as pl
+import plotly.graph_objs as go
 
 from mpi4py import MPI
 comm = MPI.COMM_WORLD
 rank = comm.rank
 
+np.random.seed(42)
 
 class WienerFilterEnergy(Energy):
     def __init__(self, position, D, j):
@@ -34,6 +35,17 @@ class WienerFilterEnergy(Energy):
         return_g.val = g.val.real
         return return_g
 
+    @property
+    def curvature(self):
+        class Dummy(object):
+            def __init__(self, x):
+                self.x = x
+            def inverse_times(self, *args, **kwargs):
+                return self.x.times(*args, **kwargs)
+        my_dummy = Dummy(self.D)
+        return my_dummy
+
+
     @memo
     def D_inverse_x(self):
         return D.inverse_times(self.position)
@@ -82,14 +94,18 @@ if __name__ == "__main__":
         x = energy.position
         print (iteration, ((x-ss).norm()/ss.norm()).real)
 
-    minimizer = SteepestDescent(convergence_tolerance=0,
-                                iteration_limit=50,
-                                callback=distance_measure)
+#    minimizer = SteepestDescent(convergence_tolerance=0,
+#                                iteration_limit=50,
+#                                callback=distance_measure)
+
+    minimizer = RelaxedNewton(convergence_tolerance=0,
+                              iteration_limit=2,
+                              callback=distance_measure)
 
-    minimizer = VL_BFGS(convergence_tolerance=0,
-                        iteration_limit=50,
-                        callback=distance_measure,
-                        max_history_length=3)
+#    minimizer = VL_BFGS(convergence_tolerance=0,
+#                        iteration_limit=50,
+#                        callback=distance_measure,
+#                        max_history_length=3)
 
     m0 = Field(s_space, val=1)
 
@@ -97,40 +113,35 @@ if __name__ == "__main__":
 
     (energy, convergence) = minimizer(energy)
 
+    m = energy.position
 
+    d_data = d.val.get_full_data().real
+    if rank == 0:
+        pl.plot([go.Heatmap(z=d_data)], filename='data.html')
+
+
+    ss_data = ss.val.get_full_data().real
+    if rank == 0:
+        pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')
+
+    sh_data = sh.val.get_full_data().real
+    if rank == 0:
+        pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')
+
+    j_data = j.val.get_full_data().real
+    if rank == 0:
+        pl.plot([go.Heatmap(z=j_data)], filename='j.html')
+
+    jabs_data = np.abs(j.val.get_full_data())
+    jphase_data = np.angle(j.val.get_full_data())
+    if rank == 0:
+        pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html')
+        pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html')
+
+    m_data = m.val.get_full_data().real
+    if rank == 0:
+        pl.plot([go.Heatmap(z=m_data)], filename='map.html')
 
-#
-#
-#
-#    grad = gradient(m)
-#
-#    d_data = d.val.get_full_data().real
-#    if rank == 0:
-#        pl.plot([go.Heatmap(z=d_data)], filename='data.html')
-#
-#
-#    ss_data = ss.val.get_full_data().real
-#    if rank == 0:
-#        pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')
-#
-#    sh_data = sh.val.get_full_data().real
-#    if rank == 0:
-#        pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')
-#
-#    j_data = j.val.get_full_data().real
-#    if rank == 0:
-#        pl.plot([go.Heatmap(z=j_data)], filename='j.html')
-#
-#    jabs_data = np.abs(j.val.get_full_data())
-#    jphase_data = np.angle(j.val.get_full_data())
-#    if rank == 0:
-#        pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html')
-#        pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html')
-#
-#    m_data = m.val.get_full_data().real
-#    if rank == 0:
-#        pl.plot([go.Heatmap(z=m_data)], filename='map.html')
-#
 #    grad_data = grad.val.get_full_data().real
 #    if rank == 0:
 #        pl.plot([go.Heatmap(z=grad_data)], filename='grad.html')
diff --git a/nifty/minimization/__init__.py b/nifty/minimization/__init__.py
index 7a9464703..2d31c55c6 100644
--- a/nifty/minimization/__init__.py
+++ b/nifty/minimization/__init__.py
@@ -5,3 +5,4 @@ from conjugate_gradient import ConjugateGradient
 from quasi_newton_minimizer import QuasiNewtonMinimizer
 from steepest_descent import SteepestDescent
 from vl_bfgs import VL_BFGS
+from relaxed_newton import RelaxedNewton
diff --git a/nifty/minimization/relaxed_newton.py b/nifty/minimization/relaxed_newton.py
index b42016055..0dc5f32d8 100644
--- a/nifty/minimization/relaxed_newton.py
+++ b/nifty/minimization/relaxed_newton.py
@@ -21,8 +21,9 @@ class RelaxedNewton(QuasiNewtonMinimizer):
         gradient = energy.gradient
         curvature = energy.curvature
         descend_direction = curvature.inverse_times(gradient)
-        norm = descend_direction.norm()
-        if norm != 1:
-            return descend_direction / -norm
-        else:
-            return descend_direction * -1
+        return descend_direction * -1
+        #norm = descend_direction.norm()
+#        if norm != 1:
+#            return descend_direction / -norm
+#        else:
+#            return descend_direction * -1
-- 
GitLab