Commit 81d4481b authored by Jakob Knollmueller's avatar Jakob Knollmueller
Browse files

some changes, nothing serious, not working

parent b400978c
......@@ -10,20 +10,21 @@ from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.rank
np.random.seed(42)
np.random.seed(43)
def plot_parameters(m, t, p, p_d):
def plot_parameters(m, t, p, p_sig,p_d):
x = np.log(t.domain[0].kindex)
m = fft.adjoint_times(m)
m = m.val.get_full_data().real
t = t.val.get_full_data().real
p = p.val.get_full_data().real
pd_sig = p_sig.val.get_full_data()
p_d = p_d.val.get_full_data().real
pl.plot([go.Heatmap(z=m)], filename='map.html', auto_open=False)
pl.plot([go.Scatter(x=x, y=t), go.Scatter(x=x, y=p),
go.Scatter(x=x, y=p_d)], filename="t.html", auto_open=False)
go.Scatter(x=x, y=p_d),go.Scatter(x=x, y=pd_sig)], filename="t.html", auto_open=False)
class AdjointFFTResponse(ift.LinearOperator):
......@@ -58,7 +59,8 @@ if __name__ == "__main__":
distribution_strategy = 'not'
# Set up position space
s_space = ift.RGSpace([128, 128])
dist = 1/128. *0.1
s_space = ift.RGSpace([128, 128], distances=[dist,dist])
# s_space = ift.HPSpace(32)
# Define harmonic transformation and associated harmonic space
......@@ -72,7 +74,8 @@ if __name__ == "__main__":
distribution_strategy=distribution_strategy)
# Choose the prior correlation structure and defining correlation operator
p_spec = (lambda k: (.5 / (k + 1) ** 3))
# p_spec = (lambda k: (.5 / (k + 1) ** 3))
p_spec = (lambda k: 1)
S = ift.create_power_operator(h_space, power_spectrum=p_spec,
distribution_strategy=distribution_strategy)
......@@ -123,7 +126,6 @@ if __name__ == "__main__":
IC3 = ift.GradientNormController(iteration_limit=100,
tol_abs_gradnorm=0.1)
minimizer3 = ift.SteepestDescent(IC3)
# Set starting position
flat_power = ift.Field(p_space, val=1e-8)
m0 = flat_power.power_synthesize(real_signal=True)
......@@ -137,11 +139,11 @@ if __name__ == "__main__":
# Initialize non-linear Wiener Filter energy
map_energy = WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S0)
# Solve the Wiener Filter analytically
D0 = map_energy.curvature
m0 = D0.inverse_times(j)
# D0 = map_energy.curvature
# m0 = D0.inverse_times(j)
# Initialize power energy with updated parameters
power_energy = CriticalPowerEnergy(position=t0, m=m0, D=D0,
smoothness_prior=10., samples=3)
power_energy = CriticalPowerEnergy(position=t0, m=sh, D=None,
smoothness_prior=1e-15, samples=3)
(power_energy, convergence) = minimizer2(power_energy)
......@@ -150,5 +152,6 @@ if __name__ == "__main__":
# Plot current estimate
print(i)
if i % 5 == 0:
plot_parameters(m0, t0, ift.log(sp), data_power)
if i % 1 == 0:
plot_parameters(sh, t0, ift.log(sp), ift.log(sh.power_analyze(binbounds=p_space.binbounds)),data_power)
print ift.log(sh.power_analyze(binbounds=p_space.binbounds)).val - t0.val
......@@ -73,10 +73,11 @@ class CriticalPowerEnergy(Energy):
self._w = w if w is not None else None
if inverter is None:
preconditioner = DiagonalOperator(self._theta.domain,
diagonal=self._theta.weight(-1),
diagonal=self._theta,
copy=False)
inverter = ConjugateGradient(preconditioner=preconditioner)
self._inverter = inverter
self.one = Field(self.position.domain,val=1.)
@property
def inverter(self):
......@@ -94,16 +95,16 @@ class CriticalPowerEnergy(Energy):
@property
@memo
def value(self):
energy = self._theta.sum()
energy += self.position.weight(-1).vdot(self._rho_prime)
energy = self.one.vdot(self._theta)
energy += self.position.vdot(self.one/2.)
energy += 0.5 * self.position.vdot(self._Tt)
return energy.real
@property
@memo
def gradient(self):
gradient = -self._theta.weight(-1)
gradient += (self._rho_prime).weight(-1)
gradient = -self._theta
gradient += (self.one/2.)
gradient += self._Tt
gradient.val = gradient.val.real
return gradient
......@@ -111,7 +112,7 @@ class CriticalPowerEnergy(Energy):
@property
@memo
def curvature(self):
return CriticalPowerCurvature(theta=self._theta.weight(-1), T=self.T,
return CriticalPowerCurvature(theta=self._theta, T=self.T,
inverter=self.inverter)
# ---Added properties and methods---
......@@ -142,7 +143,7 @@ class CriticalPowerEnergy(Energy):
w = self.m.power_analyze(
binbounds=self.position.domain[0].binbounds)
w *= self.rho
self._w = w
self._w = w.weight(-1)
return self._w
@property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment