Refactored simple likelihood.

parent 41320f6c
......@@ -16,15 +16,6 @@ class EnsembleLikelihood(Likelihood):
data_covariance = data_covariance.val.get_full_data()
self.data_covariance = data_covariance
def _strip_data(self, data):
# if the first element in the domain tuple is a FieldArray we must
# extract the data
if isinstance(data.domain[0], FieldArray):
data = Field(domain=data.domain[1:],
val=data.val.get_full_data()[0],
distribution_strategy='not')
return data
def __call__(self, observable):
field = observable[self.observable_name]
return self._process_simple_field(field,
......@@ -58,7 +49,6 @@ class EnsembleLikelihood(Likelihood):
rho = 1
else:
rho = np.min([1, numerator/denominator])
self.logger.debug("rho: %f = %f / %f" % (rho, numerator, denominator))
# rescale U half/half
......@@ -68,7 +58,7 @@ class EnsembleLikelihood(Likelihood):
self.logger.info(('rho*mu', rho*mu,
'rho', rho,
'mu', mu,
'alhpa', alpha))
'alpha', alpha))
A_diagonal_val += rho*mu
a_u_val = u_val/A_diagonal_val
......@@ -82,7 +72,7 @@ class EnsembleLikelihood(Likelihood):
# If the data was incomplete, i.e. contains np.NANs, set those values
# to zero.
c.val.data = np.nan_to_num(c.val.data)
c = np.nan_to_num(c)
# assuming that A == A^dagger, this can be shortend
# a_c = A.inverse_times(c)
# u_a_c = a_c.dot(U, spaces=1)
......@@ -94,7 +84,7 @@ class EnsembleLikelihood(Likelihood):
# Pure NIFTy is
# u_a_c = c.dot(a_u, spaces=1)
# u_a_c_val = u_a_c.val.get_full_data()
c_val = c.val.get_full_data()
c_val = c
u_a_c_val = np.einsum(c_val, [1], a_u_val, [0, 1])
first_summand_val = c_val/A_diagonal_val
......
......@@ -4,8 +4,19 @@ import abc
from keepers import Loggable
from nifty import FieldArray
class Likelihood(Loggable, object):
@abc.abstractmethod
def __call__(self, observables):
raise NotImplementedError
def _strip_data(self, data):
# if the first element in the domain tuple is a FieldArray we must
# extract the data
if isinstance(data.domain[0], FieldArray):
data = data.val.get_full_data()[0]
else:
data = data.val.get_full_data()
return data
# -*- coding: utf-8 -*-
import numpy as np
from nifty import Field
from imagine.likelihoods.likelihood import Likelihood
class SimpleLikelihood(Likelihood):
def __init__(self, measured_data, data_covariance_operator=None):
def __init__(self, measured_data, data_covariance=None):
self.measured_data = measured_data
self.data_covariance_operator = data_covariance_operator
if isinstance(data_covariance, Field):
data_covariance = data_covariance.val.get_full_data()
self.data_covariance = data_covariance
def __call__(self, observable):
data = self.measured_data.val.get_full_data()
data = self.measured_data
obs_mean = observable.ensemble_mean().val.get_full_data()
diff = data - obs_mean
if self.data_covariance_operator is not None:
right = self.data_covariance_operator.inverse_times(diff)
if self.data_covariance is not None:
right = diff/self.data_covariance_operator
else:
right = diff
return -0.5 * diff.conjugate().vdot(right)
return -0.5 * np.vdot(diff, right)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment