Commit 108c19bc authored by Martin Reinecke's avatar Martin Reinecke

adjust for redesign branch

parent b823fe58
......@@ -45,7 +45,7 @@ def generate_mock_data():
if __name__ == '__main__':
np.random.seed(42)
data = generate_mock_data()
data = np.load('sky.npy')
#data = np.load('sky.npy')
myStarblade = sb.build_starblade(data=data, alpha=1.5, cg_steps=5, q=1e-3)
for i in range(3): # not fully converged after 3 steps.
......
......@@ -56,10 +56,10 @@ class StarbladeEnergy(ift.Energy):
self.FFT = parameters['FFT']
self.correlation = ift.create_power_operator(self.FFT.domain, parameters['power_spectrum'])
self.alpham1 = parameters['alpha'] - 1.
pos_tanh = ift.PositiveTanh()
self.S = ift.SandwichOperator.make(self.FFT.adjoint, self.correlation)
self.a = pos_tanh(self.position)
self.a_p = pos_tanh.derivative(self.position)
tmp = ift.tanh(position)
self.a = 0.5*(1.+tmp)
self.a_p = 0.5*(1.-tmp**2)
self.a_pp = 2. - 4.*self.a
da = parameters['data']*self.a
self.u = ift.log(da)
......@@ -105,6 +105,6 @@ class StarbladeEnergy(ift.Energy):
O_x = ift.ScalingOperator(1./self.var_x, self.position.domain)
N_inv = ift.DiagonalOperator(point)
S_p = ift.DiagonalOperator(self.s_p)
my_S_inv = ift.SandwichOperator.make(self.FFT.inverse*S_p, self.correlation.inverse)
my_S_inv = ift.SandwichOperator.make(self.FFT.inverse(S_p), self.correlation.inverse)
curv = ift.InversionEnabler(ift.SamplingEnabler(my_S_inv+N_inv, O_x, self.parameters['sampling_controller']), self.parameters['controller'])
return curv
......@@ -17,9 +17,39 @@
# Starblade is being developed at the Max-Planck-Institut fuer Astrophysik
import nifty5 as ift
from nifty5.utilities import my_sum
class SampledKullbachLeiblerDivergence(ift.Energy):
def __init__(self, h, res_samples):
super(SampledKullbachLeiblerDivergence, self).__init__(h.position)
self._h = h
self._res_samples = res_samples
class StarbladeKL(ift.SampledKullbachLeiblerDivergence):
self._energy_list = tuple(h.at(self.position+ss)
for ss in res_samples)
def at(self, position):
return self.__class__(self._h.at(position), self._res_samples)
@property
@ift.memo
def value(self):
return (my_sum(map(lambda v: v.value, self._energy_list)) /
len(self._energy_list))
@property
@ift.memo
def gradient(self):
return (my_sum(map(lambda v: v.gradient, self._energy_list)) /
len(self._energy_list))
@property
@ift.memo
def metric(self):
return (my_sum(map(lambda v: v.metric, self._energy_list)).scale
(1./len(self._energy_list)))
class StarbladeKL(SampledKullbachLeiblerDivergence):
"""The Kullback-Leibler divergence for the starblade problem.
Parameters
......@@ -39,5 +69,5 @@ class StarbladeKL(ift.SampledKullbachLeiblerDivergence):
@property
def metric(self):
metric = ift.SampledKullbachLeiblerDivergence.metric.fget(self)
metric = SampledKullbachLeiblerDivergence.metric.fget(self)
return ift.InversionEnabler(metric, self.parameters['controller'])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment