From c0f6f64a36f5ce4c5c2f58f9c379232394c79bca Mon Sep 17 00:00:00 2001 From: "Knollmueller, Jakob (kjako)" <jakob@knollmueller.de> Date: Tue, 3 Jul 2018 16:39:18 +0200 Subject: [PATCH] added BernoulliEnergy --- demos/bernoulli_demo.py | 75 ++++++++++++++++++++++++++++++ nifty5/__init__.py | 1 - nifty5/library/__init__.py | 1 + nifty5/library/bernoulli_energy.py | 60 ++++++++++++++++++++++++ 4 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 demos/bernoulli_demo.py create mode 100644 nifty5/library/bernoulli_energy.py diff --git a/demos/bernoulli_demo.py b/demos/bernoulli_demo.py new file mode 100644 index 000000000..830db17d3 --- /dev/null +++ b/demos/bernoulli_demo.py @@ -0,0 +1,75 @@ +import nifty5 as ift +import numpy as np + + + + +if __name__ == '__main__': + # ABOUT THIS CODE + np.random.seed(41) + + # Set up the position space of the signal + # + # # One dimensional regular grid with uniform exposure + # position_space = ift.RGSpace(1024) + # exposure = np.ones(position_space.shape) + + # Two-dimensional regular grid with inhomogeneous exposure + position_space = ift.RGSpace([512, 512]) + + # # Sphere with with uniform exposure + # position_space = ift.HPSpace(128) + # exposure = ift.Field.full(position_space, 1.) + + # Defining harmonic space and transform + harmonic_space = position_space.get_default_codomain() + HT = ift.HarmonicTransformOperator(harmonic_space, position_space) + + domain = ift.MultiDomain.make({'xi': harmonic_space}) + position = ift.from_random('normal', domain) + + # Define power spectrum and amplitudes + def sqrtpspec(k): + return 1. / (20. + k**2) + + p_space = ift.PowerSpace(harmonic_space) + pd = ift.PowerDistributor(harmonic_space, p_space) + a = ift.PS_field(p_space, sqrtpspec) + A = pd(a) + + # Set up a sky model + xi = ift.Variable(position)['xi'] + logsky_h = xi * A + logsky = HT(logsky_h) + sky = ift.PointwisePositiveTanh(logsky) + + GR = ift.GeometryRemover(position_space) + # Set up instrumental response + R = GR + + # Generate mock data + d_space = R.target[0] + p = R(sky) + mock_position = ift.from_random('normal', p.position.domain) + pp = p.at(mock_position).value + data = np.random.binomial(1,pp.to_global_data().astype(np.float64)) + data = ift.Field.from_global_data(d_space, data) + + # Compute likelihood and Hamiltonian + position = ift.from_random('normal', p.position.domain) + likelihood = ift.BernoulliEnergy(p, data) + ic_cg = ift.GradientNormController(iteration_limit=50) + ic_newton = ift.GradientNormController(name='Newton', iteration_limit=30, + tol_abs_gradnorm=1e-3) + minimizer = ift.RelaxedNewton(ic_newton) + ic_sampling = ift.GradientNormController(iteration_limit=100) + + # Minimize the Hamiltonian + H = ift.Hamiltonian(likelihood, ic_sampling) + H = H.makeInvertible(ic_cg) + # minimizer = ift.SteepestDescent(ic_newton) + H, convergence = minimizer(H) + + # result_sky = sky.at(H.position).value + # ift.plot(result_sky) + diff --git a/nifty5/__init__.py b/nifty5/__init__.py index 0e1b17e39..fe63c27fb 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -6,7 +6,6 @@ from .domain_tuple import DomainTuple from .field import Field from .nonlinearities import Exponential, Linear, PositiveTanh, Tanh - from .models import * from .operators import * from .probing.utils import probe_with_posterior_samples, probe_diagonal, \ diff --git a/nifty5/library/__init__.py b/nifty5/library/__init__.py index 2678799c9..9c8962d50 100644 --- a/nifty5/library/__init__.py +++ b/nifty5/library/__init__.py @@ -7,3 +7,4 @@ from .poissonian_energy import PoissonianEnergy from .wiener_filter_curvature import WienerFilterCurvature from .wiener_filter_energy import WienerFilterEnergy from .correlated_fields import make_correlated_field, make_mf_correlated_field +from .bernoulli_energy import BernoulliEnergy diff --git a/nifty5/library/bernoulli_energy.py b/nifty5/library/bernoulli_energy.py new file mode 100644 index 000000000..535e61f50 --- /dev/null +++ b/nifty5/library/bernoulli_energy.py @@ -0,0 +1,60 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2018 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +from numpy import inf, isnan + +from ..minimization.energy import Energy +from ..operators.sandwich_operator import SandwichOperator +from ..sugar import log, makeOp + + +class BernoulliEnergy(Energy): + def __init__(self, p, d): + """ + p: Model object + + + """ + super(BernoulliEnergy, self).__init__(p.position) + self._p = p + self._d = d + + p_val = self._p.value + print p_val.min(), p_val.max() + self._value = -self._d.vdot(log(p_val)) - (1. - d).vdot(log(1.-p_val)) + if isnan(self._value): + self._value = inf + metric = makeOp(1./((p_val) * (1.-p_val))) + self._gradient = self._p.gradient.adjoint_times(metric(p_val-d)) + + self._curvature = SandwichOperator.make(self._p.gradient, metric) + + def at(self, position): + return self.__class__(self._p.at(position), self._d) + + @property + def value(self): + return self._value + + @property + def gradient(self): + return self._gradient + + @property + def curvature(self): + return self._curvature -- GitLab