diff --git a/demos/demo.py b/demos/demo.py index 77743dd20c2bc9f5dc791625d331cc1b137903a1..d5bfa19c78695ed468f3f5a87e6191bfec7622ed 100644 --- a/demos/demo.py +++ b/demos/demo.py @@ -45,7 +45,7 @@ def generate_mock_data(): if __name__ == '__main__': np.random.seed(42) data = generate_mock_data() - data = np.load('sky.npy') + #data = np.load('sky.npy') myStarblade = sb.build_starblade(data=data, alpha=1.5, cg_steps=5, q=1e-3) for i in range(3): # not fully converged after 3 steps. diff --git a/starblade/starblade_energy.py b/starblade/starblade_energy.py index 442d5ec33dbad00a5adbd507881517951a558a48..824dcbf6b7253157be8f428d8bb0363d6eb982ad 100644 --- a/starblade/starblade_energy.py +++ b/starblade/starblade_energy.py @@ -56,10 +56,10 @@ class StarbladeEnergy(ift.Energy): self.FFT = parameters['FFT'] self.correlation = ift.create_power_operator(self.FFT.domain, parameters['power_spectrum']) self.alpham1 = parameters['alpha'] - 1. - pos_tanh = ift.PositiveTanh() self.S = ift.SandwichOperator.make(self.FFT.adjoint, self.correlation) - self.a = pos_tanh(self.position) - self.a_p = pos_tanh.derivative(self.position) + tmp = ift.tanh(position) + self.a = 0.5*(1.+tmp) + self.a_p = 0.5*(1.-tmp**2) self.a_pp = 2. - 4.*self.a da = parameters['data']*self.a self.u = ift.log(da) @@ -105,6 +105,6 @@ class StarbladeEnergy(ift.Energy): O_x = ift.ScalingOperator(1./self.var_x, self.position.domain) N_inv = ift.DiagonalOperator(point) S_p = ift.DiagonalOperator(self.s_p) - my_S_inv = ift.SandwichOperator.make(self.FFT.inverse*S_p, self.correlation.inverse) + my_S_inv = ift.SandwichOperator.make(self.FFT.inverse(S_p), self.correlation.inverse) curv = ift.InversionEnabler(ift.SamplingEnabler(my_S_inv+N_inv, O_x, self.parameters['sampling_controller']), self.parameters['controller']) return curv diff --git a/starblade/starblade_kl.py b/starblade/starblade_kl.py index a6a90feb1740a5a0bff748d6dffba2aa049ab5e3..bb05a7c96a4c7819a407d77308831eb6d3111e2b 100644 --- a/starblade/starblade_kl.py +++ b/starblade/starblade_kl.py @@ -17,9 +17,39 @@ # Starblade is being developed at the Max-Planck-Institut fuer Astrophysik import nifty5 as ift +from nifty5.utilities import my_sum +class SampledKullbachLeiblerDivergence(ift.Energy): + def __init__(self, h, res_samples): + super(SampledKullbachLeiblerDivergence, self).__init__(h.position) + self._h = h + self._res_samples = res_samples -class StarbladeKL(ift.SampledKullbachLeiblerDivergence): + self._energy_list = tuple(h.at(self.position+ss) + for ss in res_samples) + + def at(self, position): + return self.__class__(self._h.at(position), self._res_samples) + + @property + @ift.memo + def value(self): + return (my_sum(map(lambda v: v.value, self._energy_list)) / + len(self._energy_list)) + + @property + @ift.memo + def gradient(self): + return (my_sum(map(lambda v: v.gradient, self._energy_list)) / + len(self._energy_list)) + + @property + @ift.memo + def metric(self): + return (my_sum(map(lambda v: v.metric, self._energy_list)).scale + (1./len(self._energy_list))) + +class StarbladeKL(SampledKullbachLeiblerDivergence): """The Kullback-Leibler divergence for the starblade problem. Parameters @@ -39,5 +69,5 @@ class StarbladeKL(ift.SampledKullbachLeiblerDivergence): @property def metric(self): - metric = ift.SampledKullbachLeiblerDivergence.metric.fget(self) + metric = SampledKullbachLeiblerDivergence.metric.fget(self) return ift.InversionEnabler(metric, self.parameters['controller'])