diff --git a/demos/bernoulli_demo.py b/demos/bernoulli_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..830db17d3633a7e0bbea51871d35ca3493a2336e --- /dev/null +++ b/demos/bernoulli_demo.py @@ -0,0 +1,75 @@ +import nifty5 as ift +import numpy as np + + + + +if __name__ == '__main__': + # ABOUT THIS CODE + np.random.seed(41) + + # Set up the position space of the signal + # + # # One dimensional regular grid with uniform exposure + # position_space = ift.RGSpace(1024) + # exposure = np.ones(position_space.shape) + + # Two-dimensional regular grid with inhomogeneous exposure + position_space = ift.RGSpace([512, 512]) + + # # Sphere with with uniform exposure + # position_space = ift.HPSpace(128) + # exposure = ift.Field.full(position_space, 1.) + + # Defining harmonic space and transform + harmonic_space = position_space.get_default_codomain() + HT = ift.HarmonicTransformOperator(harmonic_space, position_space) + + domain = ift.MultiDomain.make({'xi': harmonic_space}) + position = ift.from_random('normal', domain) + + # Define power spectrum and amplitudes + def sqrtpspec(k): + return 1. / (20. + k**2) + + p_space = ift.PowerSpace(harmonic_space) + pd = ift.PowerDistributor(harmonic_space, p_space) + a = ift.PS_field(p_space, sqrtpspec) + A = pd(a) + + # Set up a sky model + xi = ift.Variable(position)['xi'] + logsky_h = xi * A + logsky = HT(logsky_h) + sky = ift.PointwisePositiveTanh(logsky) + + GR = ift.GeometryRemover(position_space) + # Set up instrumental response + R = GR + + # Generate mock data + d_space = R.target[0] + p = R(sky) + mock_position = ift.from_random('normal', p.position.domain) + pp = p.at(mock_position).value + data = np.random.binomial(1,pp.to_global_data().astype(np.float64)) + data = ift.Field.from_global_data(d_space, data) + + # Compute likelihood and Hamiltonian + position = ift.from_random('normal', p.position.domain) + likelihood = ift.BernoulliEnergy(p, data) + ic_cg = ift.GradientNormController(iteration_limit=50) + ic_newton = ift.GradientNormController(name='Newton', iteration_limit=30, + tol_abs_gradnorm=1e-3) + minimizer = ift.RelaxedNewton(ic_newton) + ic_sampling = ift.GradientNormController(iteration_limit=100) + + # Minimize the Hamiltonian + H = ift.Hamiltonian(likelihood, ic_sampling) + H = H.makeInvertible(ic_cg) + # minimizer = ift.SteepestDescent(ic_newton) + H, convergence = minimizer(H) + + # result_sky = sky.at(H.position).value + # ift.plot(result_sky) + diff --git a/nifty5/__init__.py b/nifty5/__init__.py index 0e1b17e390d3aef8234b72eba6195f8ee67158cc..fe63c27fb1f1fcb4545c1e051ba00dd28faffb59 100644 --- a/nifty5/__init__.py +++ b/nifty5/__init__.py @@ -6,7 +6,6 @@ from .domain_tuple import DomainTuple from .field import Field from .nonlinearities import Exponential, Linear, PositiveTanh, Tanh - from .models import * from .operators import * from .probing.utils import probe_with_posterior_samples, probe_diagonal, \ diff --git a/nifty5/library/__init__.py b/nifty5/library/__init__.py index 2678799c9b90ce4c4879a418a54102cfb5f504ae..9c8962d50609a4edce4c9a0515436c0401c5ba36 100644 --- a/nifty5/library/__init__.py +++ b/nifty5/library/__init__.py @@ -7,3 +7,4 @@ from .poissonian_energy import PoissonianEnergy from .wiener_filter_curvature import WienerFilterCurvature from .wiener_filter_energy import WienerFilterEnergy from .correlated_fields import make_correlated_field, make_mf_correlated_field +from .bernoulli_energy import BernoulliEnergy diff --git a/nifty5/library/bernoulli_energy.py b/nifty5/library/bernoulli_energy.py new file mode 100644 index 0000000000000000000000000000000000000000..535e61f5066c50336a757f53e79960c16c95eb6d --- /dev/null +++ b/nifty5/library/bernoulli_energy.py @@ -0,0 +1,60 @@ +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# +# Copyright(C) 2013-2018 Max-Planck-Society +# +# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik +# and financially supported by the Studienstiftung des deutschen Volkes. + +from numpy import inf, isnan + +from ..minimization.energy import Energy +from ..operators.sandwich_operator import SandwichOperator +from ..sugar import log, makeOp + + +class BernoulliEnergy(Energy): + def __init__(self, p, d): + """ + p: Model object + + + """ + super(BernoulliEnergy, self).__init__(p.position) + self._p = p + self._d = d + + p_val = self._p.value + print p_val.min(), p_val.max() + self._value = -self._d.vdot(log(p_val)) - (1. - d).vdot(log(1.-p_val)) + if isnan(self._value): + self._value = inf + metric = makeOp(1./((p_val) * (1.-p_val))) + self._gradient = self._p.gradient.adjoint_times(metric(p_val-d)) + + self._curvature = SandwichOperator.make(self._p.gradient, metric) + + def at(self, position): + return self.__class__(self._p.at(position), self._d) + + @property + def value(self): + return self._value + + @property + def gradient(self): + return self._gradient + + @property + def curvature(self): + return self._curvature