From b27108cde264d4fb61421607e26469b28d29b266 Mon Sep 17 00:00:00 2001
From: Theo Steininger <theo.steininger@ipt.ai>
Date: Thu, 7 Sep 2017 12:04:09 +0200
Subject: [PATCH] Fixed volume factors in EnsembleLikelihood

---
 .../ensemble_likelihood/ensemble_likelihood.py      | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py b/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py
index 8bcf609..42c44fd 100644
--- a/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py
+++ b/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py
@@ -34,6 +34,8 @@ class EnsembleLikelihood(Likelihood):
         # B^{-1} c = (A_inv -
         #             A_inv U (I_k + U^dagger A_inv U)^{-1} U^dagger A_inv) c
 
+        weight = observable.domain[1].weight(1)
+
         k = observable.shape[0]
         n = observable.shape[1]
 
@@ -48,8 +50,11 @@ class EnsembleLikelihood(Likelihood):
         u_val = obs_val - obs_mean
 
         # compute quantities for OAS estimator
-        mu = np.vdot(u_val, u_val)/n
+        mu = np.vdot(u_val, u_val)*weight/n
         alpha = (np.einsum(u_val, [0, 1], u_val, [2, 1])**2).sum()
+        # correct the volume factor: one factor comes from the internal scalar
+        # product and one from the trace
+        alpha *= weight**2
 
         numerator = alpha + mu**2
         denominator = (k + 1) / (alpha - (mu**2)/n)
@@ -83,7 +88,7 @@ class EnsembleLikelihood(Likelihood):
         a_u_val = a_u.val.get_full_data()
         middle = (np.eye(k) +
                   np.einsum(u_val.conjugate(), [0, 1],
-                            a_u_val, [2, 1]))
+                            a_u_val, [2, 1])*weight)
         middle = np.linalg.inv(middle)
 #        result_array = np.zeros(k)
 #        for i in xrange(k):
@@ -114,8 +119,8 @@ class EnsembleLikelihood(Likelihood):
         second_summand = first_summand.copy_empty()
         second_summand.val = second_summand_val
 
-        result_1 = -c.dot(first_summand)
-        result_2 = -c.dot(second_summand)
+        result_1 = -c.vdot(first_summand)
+        result_2 = -c.vdot(second_summand)
         result = result_1 + result_2
         self.logger.debug("Calculated: %f + %f = %f" %
                           (result_1, result_2, result))
-- 
GitLab