Skip to content
Snippets Groups Projects

Added RelaxedNewton globally to NIFTy.

Merged Theo Steininger requested to merge newton into master
3 files
+ 59
46
Compare changes
  • Side-by-side
  • Inline
Files
3
from nifty import *
#import plotly.offline as pl
#import plotly.graph_objs as go
import plotly.offline as pl
import plotly.graph_objs as go
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.rank
np.random.seed(42)
class WienerFilterEnergy(Energy):
def __init__(self, position, D, j):
@@ -34,6 +35,17 @@ class WienerFilterEnergy(Energy):
return_g.val = g.val.real
return return_g
@property
def curvature(self):
class Dummy(object):
def __init__(self, x):
self.x = x
def inverse_times(self, *args, **kwargs):
return self.x.times(*args, **kwargs)
my_dummy = Dummy(self.D)
return my_dummy
@memo
def D_inverse_x(self):
return D.inverse_times(self.position)
@@ -82,14 +94,18 @@ if __name__ == "__main__":
x = energy.position
print (iteration, ((x-ss).norm()/ss.norm()).real)
minimizer = SteepestDescent(convergence_tolerance=0,
iteration_limit=50,
callback=distance_measure)
# minimizer = SteepestDescent(convergence_tolerance=0,
# iteration_limit=50,
# callback=distance_measure)
minimizer = RelaxedNewton(convergence_tolerance=0,
iteration_limit=2,
callback=distance_measure)
minimizer = VL_BFGS(convergence_tolerance=0,
iteration_limit=50,
callback=distance_measure,
max_history_length=3)
# minimizer = VL_BFGS(convergence_tolerance=0,
# iteration_limit=50,
# callback=distance_measure,
# max_history_length=3)
m0 = Field(s_space, val=1)
@@ -97,40 +113,35 @@ if __name__ == "__main__":
(energy, convergence) = minimizer(energy)
m = energy.position
d_data = d.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=d_data)], filename='data.html')
ss_data = ss.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')
sh_data = sh.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')
j_data = j.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=j_data)], filename='j.html')
jabs_data = np.abs(j.val.get_full_data())
jphase_data = np.angle(j.val.get_full_data())
if rank == 0:
pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html')
pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html')
m_data = m.val.get_full_data().real
if rank == 0:
pl.plot([go.Heatmap(z=m_data)], filename='map.html')
#
#
#
# grad = gradient(m)
#
# d_data = d.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=d_data)], filename='data.html')
#
#
# ss_data = ss.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=ss_data)], filename='ss.html')
#
# sh_data = sh.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=sh_data)], filename='sh.html')
#
# j_data = j.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=j_data)], filename='j.html')
#
# jabs_data = np.abs(j.val.get_full_data())
# jphase_data = np.angle(j.val.get_full_data())
# if rank == 0:
# pl.plot([go.Heatmap(z=jabs_data)], filename='j_abs.html')
# pl.plot([go.Heatmap(z=jphase_data)], filename='j_phase.html')
#
# m_data = m.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=m_data)], filename='map.html')
#
# grad_data = grad.val.get_full_data().real
# if rank == 0:
# pl.plot([go.Heatmap(z=grad_data)], filename='grad.html')
Loading