Commit c98ebc8e authored by Theo Steininger's avatar Theo Steininger

Refactored simple likelihood.

parent 41320f6c
...@@ -16,15 +16,6 @@ class EnsembleLikelihood(Likelihood): ...@@ -16,15 +16,6 @@ class EnsembleLikelihood(Likelihood):
data_covariance = data_covariance.val.get_full_data() data_covariance = data_covariance.val.get_full_data()
self.data_covariance = data_covariance self.data_covariance = data_covariance
def _strip_data(self, data):
# if the first element in the domain tuple is a FieldArray we must
# extract the data
if isinstance(data.domain[0], FieldArray):
data = Field(domain=data.domain[1:],
val=data.val.get_full_data()[0],
distribution_strategy='not')
return data
def __call__(self, observable): def __call__(self, observable):
field = observable[self.observable_name] field = observable[self.observable_name]
return self._process_simple_field(field, return self._process_simple_field(field,
...@@ -58,7 +49,6 @@ class EnsembleLikelihood(Likelihood): ...@@ -58,7 +49,6 @@ class EnsembleLikelihood(Likelihood):
rho = 1 rho = 1
else: else:
rho = np.min([1, numerator/denominator]) rho = np.min([1, numerator/denominator])
self.logger.debug("rho: %f = %f / %f" % (rho, numerator, denominator)) self.logger.debug("rho: %f = %f / %f" % (rho, numerator, denominator))
# rescale U half/half # rescale U half/half
...@@ -68,7 +58,7 @@ class EnsembleLikelihood(Likelihood): ...@@ -68,7 +58,7 @@ class EnsembleLikelihood(Likelihood):
self.logger.info(('rho*mu', rho*mu, self.logger.info(('rho*mu', rho*mu,
'rho', rho, 'rho', rho,
'mu', mu, 'mu', mu,
'alhpa', alpha)) 'alpha', alpha))
A_diagonal_val += rho*mu A_diagonal_val += rho*mu
a_u_val = u_val/A_diagonal_val a_u_val = u_val/A_diagonal_val
...@@ -82,7 +72,7 @@ class EnsembleLikelihood(Likelihood): ...@@ -82,7 +72,7 @@ class EnsembleLikelihood(Likelihood):
# If the data was incomplete, i.e. contains np.NANs, set those values # If the data was incomplete, i.e. contains np.NANs, set those values
# to zero. # to zero.
c.val.data = np.nan_to_num(c.val.data) c = np.nan_to_num(c)
# assuming that A == A^dagger, this can be shortend # assuming that A == A^dagger, this can be shortend
# a_c = A.inverse_times(c) # a_c = A.inverse_times(c)
# u_a_c = a_c.dot(U, spaces=1) # u_a_c = a_c.dot(U, spaces=1)
...@@ -94,7 +84,7 @@ class EnsembleLikelihood(Likelihood): ...@@ -94,7 +84,7 @@ class EnsembleLikelihood(Likelihood):
# Pure NIFTy is # Pure NIFTy is
# u_a_c = c.dot(a_u, spaces=1) # u_a_c = c.dot(a_u, spaces=1)
# u_a_c_val = u_a_c.val.get_full_data() # u_a_c_val = u_a_c.val.get_full_data()
c_val = c.val.get_full_data() c_val = c
u_a_c_val = np.einsum(c_val, [1], a_u_val, [0, 1]) u_a_c_val = np.einsum(c_val, [1], a_u_val, [0, 1])
first_summand_val = c_val/A_diagonal_val first_summand_val = c_val/A_diagonal_val
......
...@@ -4,8 +4,19 @@ import abc ...@@ -4,8 +4,19 @@ import abc
from keepers import Loggable from keepers import Loggable
from nifty import FieldArray
class Likelihood(Loggable, object): class Likelihood(Loggable, object):
@abc.abstractmethod @abc.abstractmethod
def __call__(self, observables): def __call__(self, observables):
raise NotImplementedError raise NotImplementedError
def _strip_data(self, data):
# if the first element in the domain tuple is a FieldArray we must
# extract the data
if isinstance(data.domain[0], FieldArray):
data = data.val.get_full_data()[0]
else:
data = data.val.get_full_data()
return data
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import numpy as np
from nifty import Field
from imagine.likelihoods.likelihood import Likelihood from imagine.likelihoods.likelihood import Likelihood
class SimpleLikelihood(Likelihood): class SimpleLikelihood(Likelihood):
def __init__(self, measured_data, data_covariance_operator=None): def __init__(self, measured_data, data_covariance=None):
self.measured_data = measured_data self.measured_data = measured_data
self.data_covariance_operator = data_covariance_operator if isinstance(data_covariance, Field):
data_covariance = data_covariance.val.get_full_data()
self.data_covariance = data_covariance
def __call__(self, observable): def __call__(self, observable):
data = self.measured_data.val.get_full_data() data = self.measured_data
obs_mean = observable.ensemble_mean().val.get_full_data() obs_mean = observable.ensemble_mean().val.get_full_data()
diff = data - obs_mean diff = data - obs_mean
if self.data_covariance_operator is not None: if self.data_covariance is not None:
right = self.data_covariance_operator.inverse_times(diff) right = diff/self.data_covariance_operator
else: else:
right = diff right = diff
return -0.5 * diff.conjugate().vdot(right) return -0.5 * np.vdot(diff, right)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment