Commit f10f5f20 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'laplace_op' into 'NIFTy_7'

Laplace distribution operator for L1 regularisation

See merge request !599
parents 75760266 acef5270
Pipeline #96745 passed with stages
in 11 minutes and 57 seconds
...@@ -77,7 +77,7 @@ from .sugar import * ...@@ -77,7 +77,7 @@ from .sugar import *
from .plot import Plot from .plot import Plot
from .library.special_distributions import InverseGammaOperator, UniformOperator from .library.special_distributions import InverseGammaOperator, UniformOperator, LaplaceOperator
from .library.los_response import LOSResponse from .library.los_response import LOSResponse
from .library.dynamic_operator import (dynamic_operator, from .library.dynamic_operator import (dynamic_operator,
dynamic_lightcone_operator) dynamic_lightcone_operator)
......
...@@ -11,13 +11,13 @@ ...@@ -11,13 +11,13 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2019 Max-Planck-Society # Copyright(C) 2013-2021 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np import numpy as np
from scipy.interpolate import CubicSpline from scipy.interpolate import CubicSpline
from scipy.stats import invgamma, norm from scipy.stats import invgamma, laplace, norm
from .. import random from .. import random
from ..domain_tuple import DomainTuple from ..domain_tuple import DomainTuple
...@@ -90,8 +90,7 @@ class _InterpolationOperator(Operator): ...@@ -90,8 +90,7 @@ class _InterpolationOperator(Operator):
def InverseGammaOperator(domain, alpha, q, delta=1e-2): def InverseGammaOperator(domain, alpha, q, delta=1e-2):
"""Transforms a Gaussian with unit covariance and zero mean into an """Transform a standard normal into an inverse gamma distribution.
inverse gamma distribution.
The pdf of the inverse gamma distribution is defined as follows: The pdf of the inverse gamma distribution is defined as follows:
...@@ -126,9 +125,9 @@ def InverseGammaOperator(domain, alpha, q, delta=1e-2): ...@@ -126,9 +125,9 @@ def InverseGammaOperator(domain, alpha, q, delta=1e-2):
class UniformOperator(Operator): class UniformOperator(Operator):
""" """Transform a standard normal into a uniform distribution.
Transforms a Gaussian with unit covariance and zero mean into a uniform
distribution. The uniform distribution's support is ``[loc, loc + scale]``. The uniform distribution's support is ``[loc, loc + scale]``.
Parameters Parameters
---------- ----------
...@@ -157,3 +156,37 @@ class UniformOperator(Operator): ...@@ -157,3 +156,37 @@ class UniformOperator(Operator):
def inverse(self, field): def inverse(self, field):
res = norm._ppf(field.val/self._scale - self._loc) res = norm._ppf(field.val/self._scale - self._loc)
return Field(field.domain, res) return Field(field.domain, res)
class LaplaceOperator(Operator):
"""Transform a standard normal to a Laplace distribution.
Parameters
-----------
domain : Domain, tuple of Domain or DomainTuple
The domain on which the field shall be defined. This is at the same
time the domain and the target of the operator.
loc : float
scale : float
"""
def __init__(self, domain, loc=0, scale=1):
self._target = self._domain = DomainTuple.make(domain)
self._loc = float(loc)
self._scale = float(scale)
def apply(self, x):
self._check_input(x)
lin = x.jac is not None
xval = x.val.val if lin else x.val
res = Field(self._target, laplace.ppf(norm._cdf(xval), self._loc, self._scale))
if not lin:
return res
y = norm._cdf(xval)
y = self._scale * np.where(y > 0.5, 1/(1-y), 1/y)
jac = makeOp(Field(self.domain, y*norm._pdf(xval)))
return x.new(res, jac)
def inverse(self, x):
res = laplace._cdf(x.val)
return Field(x.domain, res)
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2020 Max-Planck-Society # Copyright(C) 2013-2021 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
...@@ -88,14 +88,22 @@ def testBinary(type1, type2, space, seed): ...@@ -88,14 +88,22 @@ def testBinary(type1, type2, space, seed):
ift.extra.check_operator(model, pos, ntries=ntries) ift.extra.check_operator(model, pos, ntries=ntries)
def testSpecialDistributionOps(space, seed): def testInverseGamma(space, seed):
with ift.random.Context(seed): with ift.random.Context(seed):
pos = ift.from_random(space, 'normal') pos = ift.from_random(space, 'normal')
alpha = 1.5 alpha = 1.5
q = 0.73 q = 0.73
model = ift.InverseGammaOperator(space, alpha, q) model = ift.InverseGammaOperator(space, alpha, q)
ift.extra.check_operator(model, pos, ntries=20) ift.extra.check_operator(model, pos, ntries=20)
model = ift.UniformOperator(space, alpha, q)
@pmp("loc", [0, 13.2])
@pmp("scale", [1, 551.09])
@pmp("op", [ift.UniformOperator, ift.LaplaceOperator])
def testSpecialDistributionOps(space, seed, loc, scale, op):
with ift.random.Context(seed):
pos = ift.from_random(space, 'normal')
model = op(space, loc, scale)
ift.extra.check_operator(model, pos, ntries=20) ift.extra.check_operator(model, pos, ntries=20)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment