diff --git a/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py b/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py index a3ef77e2c7b5d21ea2820a473f46153439f00c9a..4c35fa8bb61b62d0f8473c4bd645c1a621d68ca4 100644 --- a/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py +++ b/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py @@ -2,7 +2,7 @@ import numpy as np -from nifty import DiagonalOperator +from nifty import DiagonalOperator, FieldArray, Field from imagine.likelihoods.likelihood import Likelihood from imagine.create_ring_profile import create_ring_profile @@ -12,13 +12,24 @@ class EnsembleLikelihood(Likelihood): def __init__(self, observable_name, measured_data, data_covariance_operator, profile=None): self.observable_name = observable_name - self.measured_data = measured_data + self.measured_data = self._strip_data(measured_data) self.data_covariance_operator = data_covariance_operator + self.data_covariance_includes_profile = False + if profile is None: profile = create_ring_profile( self.measured_data.val.get_full_data()) self.profile = profile + def _strip_data(self, data): + # if the first element in the domain tuple is a FieldArray we must + # extract the data + if isinstance(data.domain[0], FieldArray): + stripped_data = Field(domain=data.domain[1:], + val=data.val.get_full_data()[0], + distribution_strategy='not') + return stripped_data + def __call__(self, observable): field = observable[self.observable_name] return self._process_simple_field(field, @@ -80,6 +91,8 @@ class EnsembleLikelihood(Likelihood): "DiagonalOperator.") A_bare_diagonal = data_covariance_operator.diagonal(bare=True) + if not self.data_covariance_includes_profile: + A_bare_diagonal *= (profile**2) A_bare_diagonal.val += rho*mu A = DiagonalOperator( domain=data_covariance_operator.domain, @@ -97,6 +110,10 @@ class EnsembleLikelihood(Likelihood): middle = np.linalg.inv(middle) c = measured_data - obs_mean + # If the data was incomplete, i.e. contains np.NANs, set those values + # to zero. + np.nan_to_num(c, copy=False) + # assuming that A == A^dagger, this can be shortend # a_c = A.inverse_times(c) # u_a_c = a_c.dot(U, spaces=1)