Commit 7b98c3bb authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'line_search' of gitlab.mpcdf.mpg.de:ift/NIFTy into line_search

parents b5f2473f 41309d1f
from nifty import *
from nifty.library.wiener_filter import WienerFilterEnergy
import numpy as np
from nifty import (VL_BFGS, DiagonalOperator, FFTOperator, Field,
LinearOperator, PowerSpace, RelaxedNewton, RGSpace,
SteepestDescent, create_power_operator, exp, log, sqrt)
from nifty.library.critical_filter import CriticalPowerEnergy
import plotly.offline as pl
import plotly.graph_objs as go
from nifty.library.wiener_filter import WienerFilterEnergy
import plotly.graph_objs as go
import plotly.offline as pl
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.rank
np.random.seed(42)
def plot_parameters(m,t,p, p_d):
def plot_parameters(m, t, p, p_d):
x = log(t.domain[0].kindex)
m = fft.adjoint_times(m)
......@@ -20,7 +24,8 @@ def plot_parameters(m,t,p, p_d):
p = p.val.get_full_data().real
p_d = p_d.val.get_full_data().real
pl.plot([go.Heatmap(z=m)], filename='map.html')
pl.plot([go.Scatter(x=x,y=t), go.Scatter(x=x ,y=p), go.Scatter(x=x, y=p_d)], filename="t.html")
pl.plot([go.Scatter(x=x, y=t), go.Scatter(x=x, y=p),
go.Scatter(x=x, y=p_d)], filename="t.html")
class AdjointFFTResponse(LinearOperator):
......@@ -36,6 +41,7 @@ class AdjointFFTResponse(LinearOperator):
def _adjoint_times(self, x, spaces=None):
return self.FFT(self.R.adjoint_times(x))
@property
def domain(self):
return self._domain
......@@ -48,41 +54,40 @@ class AdjointFFTResponse(LinearOperator):
def unitary(self):
return False
if __name__ == "__main__":
distribution_strategy = 'not'
# Set up position space
s_space = RGSpace([128,128])
s_space = RGSpace([128, 128])
# s_space = HPSpace(32)
# Define harmonic transformation and associated harmonic space
fft = FFTOperator(s_space)
h_space = fft.target[0]
# Setting up power space
# Set up power space
p_space = PowerSpace(h_space, logarithmic=True,
distribution_strategy=distribution_strategy)
# Choosing the prior correlation structure and defining correlation operator
# Choose the prior correlation structure and defining correlation operator
p_spec = (lambda k: (.5 / (k + 1) ** 3))
S = create_power_operator(h_space, power_spectrum=p_spec,
distribution_strategy=distribution_strategy)
# Drawing a sample sh from the prior distribution in harmonic space
# Draw a sample sh from the prior distribution in harmonic space
sp = Field(p_space, val=p_spec,
distribution_strategy=distribution_strategy)
sh = sp.power_synthesize(real_signal=True)
# Choosing the measurement instrument
# Choose the measurement instrument
# Instrument = SmoothingOperator(s_space, sigma=0.01)
Instrument = DiagonalOperator(s_space, diagonal=1.)
# Instrument._diagonal.val[200:400, 200:400] = 0
#Instrument._diagonal.val[64:512-64, 64:512-64] = 0
# Instrument._diagonal.val[64:512-64, 64:512-64] = 0
#Adding a harmonic transformation to the instrument
# Add a harmonic transformation to the instrument
R = AdjointFFTResponse(fft, Instrument)
noise = 1.
......@@ -92,7 +97,7 @@ if __name__ == "__main__":
std=sqrt(noise),
mean=0)
# Creating the mock data
# Create mock data
d = R(sh) + n
# The information source
......@@ -103,56 +108,49 @@ if __name__ == "__main__":
if rank == 0:
pl.plot([go.Heatmap(z=d_data)], filename='data.html')
# minimization strategy
def convergence_measure(a_energy, iteration): # returns current energy
# Minimization strategy
def convergence_measure(a_energy, iteration): # returns current energy
x = a_energy.value
print (x, iteration)
print(x, iteration)
minimizer1 = RelaxedNewton(convergence_tolerance=1e-8,
convergence_level=1,
iteration_limit=5,
callback=convergence_measure)
convergence_level=1,
iteration_limit=5,
callback=convergence_measure)
minimizer2 = VL_BFGS(convergence_tolerance=1e-8,
convergence_level=1,
iteration_limit=1000,
callback=convergence_measure,
max_history_length=20)
convergence_level=1,
iteration_limit=1000,
callback=convergence_measure,
max_history_length=20)
minimizer3 = SteepestDescent(convergence_tolerance=1e-8,
iteration_limit=500,
callback=convergence_measure)
iteration_limit=500,
callback=convergence_measure)
# Setting starting position
flat_power = Field(p_space,val=1e-8)
# Set starting position
flat_power = Field(p_space, val=1e-8)
m0 = flat_power.power_synthesize(real_signal=True)
t0 = Field(p_space, val=log(1./(1+p_space.kindex)**2))
for i in range(500):
for i in range(50):
S0 = create_power_operator(h_space, power_spectrum=exp(t0),
distribution_strategy=distribution_strategy)
distribution_strategy=distribution_strategy)
# Initializing the nonlinear Wiener Filter energy
# Initialize non-linear Wiener Filter energy
map_energy = WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S0)
# Solving the Wiener Filter analytically
# Solve the Wiener Filter analytically
D0 = map_energy.curvature
m0 = D0.inverse_times(j)
# Initializing the power energy with updated parameters
power_energy = CriticalPowerEnergy(position=t0, m=m0, D=D0, smoothness_prior=10., samples=3)
# Initialize power energy with updated parameters
power_energy = CriticalPowerEnergy(position=t0, m=m0, D=D0,
smoothness_prior=10., samples=3)
(power_energy, convergence) = minimizer2(power_energy)
# Set new power spectrum
t0.val = power_energy.position.val.real
# Setting new power spectrum
t0.val = power_energy.position.val.real
# Plotting current estimate
print i
if i%50 == 0:
plot_parameters(m0,t0,log(sp), data_power)
# Plot current estimate
print(i)
if i % 5 == 0:
plot_parameters(m0, t0, log(sp), data_power)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment