Commit 5a9a730d authored by Philipp Arras's avatar Philipp Arras

Define domains of likelihoods consistently

Related to Issue #62
parent 6a5ec66d
......@@ -86,9 +86,9 @@ if __name__ == '__main__':
data = ift.Field.from_global_data(d_space, data)
# Compute likelihood and Hamiltonian
likelihood = ift.PoissonianEnergy(lamb, data)
ic_newton = ift.DeltaEnergyController(name='Newton', iteration_limit=100,
tol_rel_deltaE=1e-8)
likelihood = ift.PoissonianEnergy(data)(lamb)
minimizer = ift.NewtonCG(ic_newton)
# Minimize the Hamiltonian
......
......@@ -66,4 +66,4 @@ def make_adjust_variances(a, xi, position, samples=[], scaling=None,
if scaling is not None:
x = ScalingOperator(scaling, x.target)(x)
return Hamiltonian(InverseGammaLikelihood(x, d_eval), ic_samp=ic_samp)
return Hamiltonian(InverseGammaLikelihood(d_eval)(x), ic_samp=ic_samp)
......@@ -101,13 +101,12 @@ class GaussianEnergy(EnergyOperator):
class PoissonianEnergy(EnergyOperator):
def __init__(self, op, d):
self._op, self._d = op, d
self._domain = d.domain
def __init__(self, d):
self._d = d
self._domain = DomainTuple.make(d.domain)
def apply(self, x):
self._check_input(x)
x = self._op(x)
res = x.sum() - x.log().vdot(self._d)
if not isinstance(x, Linearization):
return Field.scalar(res)
......@@ -118,13 +117,12 @@ class PoissonianEnergy(EnergyOperator):
class InverseGammaLikelihood(EnergyOperator):
def __init__(self, op, d):
self._op, self._d = op, d
self._domain = d.domain
def __init__(self, d):
self._d = d
self._domain = DomainTuple.make(d.domain)
def apply(self, x):
self._check_input(x)
x = self._op(x)
res = 0.5*(x.log().sum() + (1./x).vdot(self._d))
if not isinstance(x, Linearization):
return Field.scalar(res)
......@@ -135,14 +133,12 @@ class InverseGammaLikelihood(EnergyOperator):
class BernoulliEnergy(EnergyOperator):
def __init__(self, p, d):
self._p = p
def __init__(self, d):
self._d = d
self._domain = d.domain
self._domain = DomainTuple.make(d.domain)
def apply(self, x):
self._check_input(x)
x = self._p(x)
v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
if not isinstance(x, Linearization):
return Field.scalar(v)
......
......@@ -61,11 +61,7 @@ class Operator(NiftyMetaBase()):
def _check_input(self, x):
from ..linearization import Linearization
print('checkinput')
d = x.target if isinstance(x, Linearization) else x.domain
print(d)
print(self._domain)
print()
if self._domain != d:
raise ValueError("The operator's and field's domains don't match.")
......
......@@ -66,7 +66,7 @@ class Energy_Tests(unittest.TestCase):
model = self.make_model(space_key='s1', space=space, seed=seed)['s1']
d = np.random.normal(10, size=space.shape)**2
d = ift.Field.from_global_data(space, d)
energy = ift.InverseGammaLikelihood(ift.exp, d)
energy = ift.InverseGammaLikelihood(d).exp()
ift.extra.check_value_gradient_consistency(energy, model, tol=1e-7)
@expand(product(
......@@ -80,7 +80,7 @@ class Energy_Tests(unittest.TestCase):
space_key='s1', space=space, seed=seed)['s1']
d = np.random.poisson(120, size=space.shape)
d = ift.Field.from_global_data(space, d)
energy = ift.PoissonianEnergy(ift.exp, d)
energy = ift.PoissonianEnergy(d).exp()
ift.extra.check_value_gradient_consistency(energy, model, tol=1e-7)
@expand(product(
......@@ -113,5 +113,5 @@ class Energy_Tests(unittest.TestCase):
model = model.positive_tanh()
d = np.random.binomial(1, 0.1, size=space.shape)
d = ift.Field.from_global_data(space, d)
energy = ift.BernoulliEnergy(ift.positive_tanh, d)
energy = ift.BernoulliEnergy(d).positive_tanh()
ift.extra.check_value_gradient_consistency(energy, model, tol=2e-7)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment