Commit 5a9a730d authored by Philipp Arras's avatar Philipp Arras

Define domains of likelihoods consistently

Related to Issue #62
parent 6a5ec66d
...@@ -86,9 +86,9 @@ if __name__ == '__main__': ...@@ -86,9 +86,9 @@ if __name__ == '__main__':
data = ift.Field.from_global_data(d_space, data) data = ift.Field.from_global_data(d_space, data)
# Compute likelihood and Hamiltonian # Compute likelihood and Hamiltonian
likelihood = ift.PoissonianEnergy(lamb, data)
ic_newton = ift.DeltaEnergyController(name='Newton', iteration_limit=100, ic_newton = ift.DeltaEnergyController(name='Newton', iteration_limit=100,
tol_rel_deltaE=1e-8) tol_rel_deltaE=1e-8)
likelihood = ift.PoissonianEnergy(data)(lamb)
minimizer = ift.NewtonCG(ic_newton) minimizer = ift.NewtonCG(ic_newton)
# Minimize the Hamiltonian # Minimize the Hamiltonian
......
...@@ -66,4 +66,4 @@ def make_adjust_variances(a, xi, position, samples=[], scaling=None, ...@@ -66,4 +66,4 @@ def make_adjust_variances(a, xi, position, samples=[], scaling=None,
if scaling is not None: if scaling is not None:
x = ScalingOperator(scaling, x.target)(x) x = ScalingOperator(scaling, x.target)(x)
return Hamiltonian(InverseGammaLikelihood(x, d_eval), ic_samp=ic_samp) return Hamiltonian(InverseGammaLikelihood(d_eval)(x), ic_samp=ic_samp)
...@@ -101,13 +101,12 @@ class GaussianEnergy(EnergyOperator): ...@@ -101,13 +101,12 @@ class GaussianEnergy(EnergyOperator):
class PoissonianEnergy(EnergyOperator): class PoissonianEnergy(EnergyOperator):
def __init__(self, op, d): def __init__(self, d):
self._op, self._d = op, d self._d = d
self._domain = d.domain self._domain = DomainTuple.make(d.domain)
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
x = self._op(x)
res = x.sum() - x.log().vdot(self._d) res = x.sum() - x.log().vdot(self._d)
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return Field.scalar(res) return Field.scalar(res)
...@@ -118,13 +117,12 @@ class PoissonianEnergy(EnergyOperator): ...@@ -118,13 +117,12 @@ class PoissonianEnergy(EnergyOperator):
class InverseGammaLikelihood(EnergyOperator): class InverseGammaLikelihood(EnergyOperator):
def __init__(self, op, d): def __init__(self, d):
self._op, self._d = op, d self._d = d
self._domain = d.domain self._domain = DomainTuple.make(d.domain)
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
x = self._op(x)
res = 0.5*(x.log().sum() + (1./x).vdot(self._d)) res = 0.5*(x.log().sum() + (1./x).vdot(self._d))
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return Field.scalar(res) return Field.scalar(res)
...@@ -135,14 +133,12 @@ class InverseGammaLikelihood(EnergyOperator): ...@@ -135,14 +133,12 @@ class InverseGammaLikelihood(EnergyOperator):
class BernoulliEnergy(EnergyOperator): class BernoulliEnergy(EnergyOperator):
def __init__(self, p, d): def __init__(self, d):
self._p = p
self._d = d self._d = d
self._domain = d.domain self._domain = DomainTuple.make(d.domain)
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
x = self._p(x)
v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d) v = x.log().vdot(-self._d) - (1.-x).log().vdot(1.-self._d)
if not isinstance(x, Linearization): if not isinstance(x, Linearization):
return Field.scalar(v) return Field.scalar(v)
......
...@@ -61,11 +61,7 @@ class Operator(NiftyMetaBase()): ...@@ -61,11 +61,7 @@ class Operator(NiftyMetaBase()):
def _check_input(self, x): def _check_input(self, x):
from ..linearization import Linearization from ..linearization import Linearization
print('checkinput')
d = x.target if isinstance(x, Linearization) else x.domain d = x.target if isinstance(x, Linearization) else x.domain
print(d)
print(self._domain)
print()
if self._domain != d: if self._domain != d:
raise ValueError("The operator's and field's domains don't match.") raise ValueError("The operator's and field's domains don't match.")
......
...@@ -66,7 +66,7 @@ class Energy_Tests(unittest.TestCase): ...@@ -66,7 +66,7 @@ class Energy_Tests(unittest.TestCase):
model = self.make_model(space_key='s1', space=space, seed=seed)['s1'] model = self.make_model(space_key='s1', space=space, seed=seed)['s1']
d = np.random.normal(10, size=space.shape)**2 d = np.random.normal(10, size=space.shape)**2
d = ift.Field.from_global_data(space, d) d = ift.Field.from_global_data(space, d)
energy = ift.InverseGammaLikelihood(ift.exp, d) energy = ift.InverseGammaLikelihood(d).exp()
ift.extra.check_value_gradient_consistency(energy, model, tol=1e-7) ift.extra.check_value_gradient_consistency(energy, model, tol=1e-7)
@expand(product( @expand(product(
...@@ -80,7 +80,7 @@ class Energy_Tests(unittest.TestCase): ...@@ -80,7 +80,7 @@ class Energy_Tests(unittest.TestCase):
space_key='s1', space=space, seed=seed)['s1'] space_key='s1', space=space, seed=seed)['s1']
d = np.random.poisson(120, size=space.shape) d = np.random.poisson(120, size=space.shape)
d = ift.Field.from_global_data(space, d) d = ift.Field.from_global_data(space, d)
energy = ift.PoissonianEnergy(ift.exp, d) energy = ift.PoissonianEnergy(d).exp()
ift.extra.check_value_gradient_consistency(energy, model, tol=1e-7) ift.extra.check_value_gradient_consistency(energy, model, tol=1e-7)
@expand(product( @expand(product(
...@@ -113,5 +113,5 @@ class Energy_Tests(unittest.TestCase): ...@@ -113,5 +113,5 @@ class Energy_Tests(unittest.TestCase):
model = model.positive_tanh() model = model.positive_tanh()
d = np.random.binomial(1, 0.1, size=space.shape) d = np.random.binomial(1, 0.1, size=space.shape)
d = ift.Field.from_global_data(space, d) d = ift.Field.from_global_data(space, d)
energy = ift.BernoulliEnergy(ift.positive_tanh, d) energy = ift.BernoulliEnergy(d).positive_tanh()
ift.extra.check_value_gradient_consistency(energy, model, tol=2e-7) ift.extra.check_value_gradient_consistency(energy, model, tol=2e-7)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment