Skip to content
Snippets Groups Projects
Commit ab4209bf authored by Theo Steininger's avatar Theo Steininger
Browse files

Added RelaxedNewton globally to NIFTy.

Fixed descent_direction norming.
Extended wiener_filter_hamiltonian.py <- activated plotting.
parent 2e5af433
No related branches found
No related tags found
1 merge request!50Added RelaxedNewton globally to NIFTy.
Pipeline #
from nifty import *
#import plotly.offline as pl
#import plotly.graph_objs as go
import plotly.offline as pl
import plotly.graph_objs as go
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.rank
np.random.seed(42)
class WienerFilterEnergy(Energy):
def __init__(self, position, D, j):
......@@ -34,6 +35,17 @@ class WienerFilterEnergy(Energy):
return_g.val = g.val.real
return return_g
@property
def curvature(self):
class Dummy(object):
def __init__(self, x):
self.x = x
def inverse_times(self, *args, **kwargs):
return self.x.times(*args, **kwargs)
my_dummy = Dummy(self.D)
return my_dummy
@memo
def D_inverse_x(self):
return D.inverse_times(self.position)
......@@ -82,14 +94,18 @@ if __name__ == "__main__":
x = energy.position
print (iteration, ((x-ss).norm()/ss.norm()).real)
minimizer = SteepestDescent(convergence_tolerance=0,
iteration_limit=50,
callback=distance_measure)
# minimizer = SteepestDescent(convergence_tolerance=0,
# iteration_limit=50,
# callback=distance_measure)
minimizer = RelaxedNewton(convergence_tolerance=0,
iteration_limit=2,
callback=distance_measure)
minimizer = VL_BFGS(convergence_tolerance=0,
iteration_limit=50,
callback=distance_measure,
max_history_length=3)
# minimizer = VL_BFGS(convergence_tolerance=0,
# iteration_limit=50,
# callback=distance_measure,
# max_history_length=3)
m0 = Field(s_space, val=1)
......@@ -97,40 +113,35 @@ if __name__ == "__main__":
(energy, convergence) = minimizer(energy)
m = energy.position
d_data = d.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=d_data)], filename='data.html')
ss_data = ss.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')
sh_data = sh.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')
j_data = j.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=j_data)], filename='j.html')
jabs_data = np.abs(j.val.get_full_data())
jphase_data = np.angle(j.val.get_full_data())
if rank == 0:
pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html')
pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html')
m_data = m.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=m_data)], filename='map.html')
#
#
#
# grad = gradient(m)
#
# d_data = d.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
#
#
# ss_data = ss.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')
#
# sh_data = sh.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')
#
# j_data = j.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=j_data)], filename='j.html')
#
# jabs_data = np.abs(j.val.get_full_data())
# jphase_data = np.angle(j.val.get_full_data())
# if rank == 0:
# pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html')
# pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html')
#
# m_data = m.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
#
# grad_data = grad.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=grad_data)], filename='grad.html')
......@@ -5,3 +5,4 @@ from conjugate_gradient import ConjugateGradient
from quasi_newton_minimizer import QuasiNewtonMinimizer
from steepest_descent import SteepestDescent
from vl_bfgs import VL_BFGS
from relaxed_newton import RelaxedNewton
......@@ -21,8 +21,9 @@ class RelaxedNewton(QuasiNewtonMinimizer):
gradient = energy.gradient
curvature = energy.curvature
descend_direction = curvature.inverse_times(gradient)
norm = descend_direction.norm()
if norm != 1:
return descend_direction / -norm
else:
return descend_direction * -1
return descend_direction * -1
#norm = descend_direction.norm()
# if norm != 1:
# return descend_direction / -norm
# else:
# return descend_direction * -1
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment