From 2e82b9bccd849f8923845e7ca502d4e34aa9599c Mon Sep 17 00:00:00 2001
From: Theo Steininger <theo.steininger@ultimanet.de>
Date: Thu, 14 Dec 2017 03:44:50 +0100
Subject: [PATCH] Added sqrt(n) to U in EnsembleLikelihood

---
 .../ensemble_likelihood.py                    | 35 +++++++++----------
 imagine/likelihoods/likelihood/likelihood.py  |  3 ++
 2 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py b/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py
index 66ab2c7..88c3537 100644
--- a/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py
+++ b/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py
@@ -36,10 +36,11 @@ class EnsembleLikelihood(Likelihood):
         obs_val = observable.val.get_full_data()
         obs_mean = observable.ensemble_mean().val.get_full_data()
 
-        u_val = obs_val - obs_mean
+        U = obs_val - obs_mean
+        U *= np.sqrt(n)
         # compute quantities for OAS estimator
-        mu = np.vdot(u_val, u_val)/k/n
-        alpha = (np.einsum(u_val, [0, 1], u_val, [2, 1])**2).sum()
+        mu = np.vdot(U, U)/k/n
+        alpha = (np.einsum(U, [0, 1], U, [2, 1])**2).sum()
         alpha /= k**2
 
         numerator = (1 - 2./n)*alpha + (mu*n)**2
@@ -52,22 +53,21 @@ class EnsembleLikelihood(Likelihood):
         self.logger.debug("rho: %f = %f / %f" % (rho, numerator, denominator))
 
         # rescale U half/half
-        u_val *= np.sqrt(1-rho) / np.sqrt(k)
+        V = U * np.sqrt(1-rho) / np.sqrt(k)
 
-        A_diagonal_val = data_covariance
-        self.logger.info(('A_mean', np.mean(A_diagonal_val),
+        self.logger.info(('data_cov', np.mean(data_covariance),
                           'rho*mu', rho*mu,
                           'rho', rho,
                           'mu', mu,
                           'alpha', alpha))
-        A_diagonal_val += rho*mu
+        B = data_covariance + rho*mu
 
-        a_u_val = u_val/A_diagonal_val
+        V_B = V/B
 
         # build middle-matrix (kxk)
         middle = (np.eye(k) +
-                  np.einsum(u_val.conjugate(), [0, 1],
-                            a_u_val, [2, 1]))
+                  np.einsum(V.conjugate(), [0, 1],
+                            V_B, [2, 1]))
         middle = np.linalg.inv(middle)
         c = measured_data - obs_mean
 
@@ -85,23 +85,22 @@ class EnsembleLikelihood(Likelihood):
         # Pure NIFTy is
         # u_a_c = c.dot(a_u, spaces=1)
         # u_a_c_val = u_a_c.val.get_full_data()
-        c_val = c
-        u_a_c_val = np.einsum(c_val, [1], a_u_val, [0, 1])
+        V_B_c = np.einsum(c, [1], V_B, [0, 1])
 
-        first_summand_val = c_val/A_diagonal_val
-        second_summand_val = np.einsum(middle, [0, 1], u_a_c_val, [1])
-        second_summand_val = np.einsum(a_u_val, [0, 1],
+        first_summand_val = c/B
+        second_summand_val = np.einsum(middle, [0, 1], V_B_c, [1])
+        second_summand_val = np.einsum(V_B, [0, 1],
                                        second_summand_val, [0])
 #        # second_summand_val *= -1
 #        second_summand = first_summand.copy_empty()
 #        second_summand.val = second_summand_val
 
-        result_1 = np.vdot(c_val, first_summand_val)
-        result_2 = -np.vdot(c_val, second_summand_val)
+        result_1 = np.vdot(c, first_summand_val)
+        result_2 = -np.vdot(c, second_summand_val)
 
         # compute regularizing determinant of the covariance
         # det(A + UV^T) =  det(A) det(I + V^T A^-1 U)
-        log_det_1 = np.sum(np.log(A_diagonal_val))
+        log_det_1 = np.sum(np.log(B))
         (sign, log_det_2) = np.linalg.slogdet(middle)
         if sign < 0:
             self.logger.error("Negative determinant of covariance!")
diff --git a/imagine/likelihoods/likelihood/likelihood.py b/imagine/likelihoods/likelihood/likelihood.py
index dc99f45..1192276 100644
--- a/imagine/likelihoods/likelihood/likelihood.py
+++ b/imagine/likelihoods/likelihood/likelihood.py
@@ -15,6 +15,9 @@ class Likelihood(Loggable, object):
     def _strip_data(self, data):
         # if the first element in the domain tuple is a FieldArray we must
         # extract the data
+        if not hasattr(data, 'domain'):
+            return data
+
         if isinstance(data.domain[0], FieldArray):
             data = data.val.get_full_data()[0]
         else:
-- 
GitLab