From 108c19bc4719d183bda56288e056b32cba16342c Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Mon, 13 Aug 2018 12:27:16 +0200 Subject: [PATCH] adjust for redesign branch --- demos/demo.py | 2 +- starblade/starblade_energy.py | 8 ++++---- starblade/starblade_kl.py | 34 ++++++++++++++++++++++++++++++++-- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/demos/demo.py b/demos/demo.py index 77743dd..d5bfa19 100644 --- a/demos/demo.py +++ b/demos/demo.py @@ -45,7 +45,7 @@ def generate_mock_data(): if __name__ == '__main__': np.random.seed(42) data = generate_mock_data() - data = np.load('sky.npy') + #data = np.load('sky.npy') myStarblade = sb.build_starblade(data=data, alpha=1.5, cg_steps=5, q=1e-3) for i in range(3): # not fully converged after 3 steps. diff --git a/starblade/starblade_energy.py b/starblade/starblade_energy.py index 442d5ec..824dcbf 100644 --- a/starblade/starblade_energy.py +++ b/starblade/starblade_energy.py @@ -56,10 +56,10 @@ class StarbladeEnergy(ift.Energy): self.FFT = parameters['FFT'] self.correlation = ift.create_power_operator(self.FFT.domain, parameters['power_spectrum']) self.alpham1 = parameters['alpha'] - 1. - pos_tanh = ift.PositiveTanh() self.S = ift.SandwichOperator.make(self.FFT.adjoint, self.correlation) - self.a = pos_tanh(self.position) - self.a_p = pos_tanh.derivative(self.position) + tmp = ift.tanh(position) + self.a = 0.5*(1.+tmp) + self.a_p = 0.5*(1.-tmp**2) self.a_pp = 2. - 4.*self.a da = parameters['data']*self.a self.u = ift.log(da) @@ -105,6 +105,6 @@ class StarbladeEnergy(ift.Energy): O_x = ift.ScalingOperator(1./self.var_x, self.position.domain) N_inv = ift.DiagonalOperator(point) S_p = ift.DiagonalOperator(self.s_p) - my_S_inv = ift.SandwichOperator.make(self.FFT.inverse*S_p, self.correlation.inverse) + my_S_inv = ift.SandwichOperator.make(self.FFT.inverse(S_p), self.correlation.inverse) curv = ift.InversionEnabler(ift.SamplingEnabler(my_S_inv+N_inv, O_x, self.parameters['sampling_controller']), self.parameters['controller']) return curv diff --git a/starblade/starblade_kl.py b/starblade/starblade_kl.py index a6a90fe..bb05a7c 100644 --- a/starblade/starblade_kl.py +++ b/starblade/starblade_kl.py @@ -17,9 +17,39 @@ # Starblade is being developed at the Max-Planck-Institut fuer Astrophysik import nifty5 as ift +from nifty5.utilities import my_sum +class SampledKullbachLeiblerDivergence(ift.Energy): + def __init__(self, h, res_samples): + super(SampledKullbachLeiblerDivergence, self).__init__(h.position) + self._h = h + self._res_samples = res_samples -class StarbladeKL(ift.SampledKullbachLeiblerDivergence): + self._energy_list = tuple(h.at(self.position+ss) + for ss in res_samples) + + def at(self, position): + return self.__class__(self._h.at(position), self._res_samples) + + @property + @ift.memo + def value(self): + return (my_sum(map(lambda v: v.value, self._energy_list)) / + len(self._energy_list)) + + @property + @ift.memo + def gradient(self): + return (my_sum(map(lambda v: v.gradient, self._energy_list)) / + len(self._energy_list)) + + @property + @ift.memo + def metric(self): + return (my_sum(map(lambda v: v.metric, self._energy_list)).scale + (1./len(self._energy_list))) + +class StarbladeKL(SampledKullbachLeiblerDivergence): """The Kullback-Leibler divergence for the starblade problem. Parameters @@ -39,5 +69,5 @@ class StarbladeKL(ift.SampledKullbachLeiblerDivergence): @property def metric(self): - metric = ift.SampledKullbachLeiblerDivergence.metric.fget(self) + metric = SampledKullbachLeiblerDivergence.metric.fget(self) return ift.InversionEnabler(metric, self.parameters['controller']) -- GitLab