From cc683044e74e574a5391bdecfc381ce8eba26c09 Mon Sep 17 00:00:00 2001
From: Theo Steininger <theo.steininger@ipt.ai>
Date: Sun, 27 Aug 2017 02:27:36 +0200
Subject: [PATCH] Implemented first version of OAS covariance estimator for
 ensemble likelihood.

---
 .../ensemble_likelihood.py                    | 38 +++++++++++++++++--
 1 file changed, 35 insertions(+), 3 deletions(-)

diff --git a/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py b/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py
index fd40180..627dc8d 100644
--- a/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py
+++ b/imagine/likelihoods/ensemble_likelihood/ensemble_likelihood.py
@@ -2,6 +2,8 @@
 
 import numpy as np
 
+from nifty import DiagonalOperator
+
 from imagine.likelihoods.likelihood import Likelihood
 from imagine.create_ring_profile import create_ring_profile
 
@@ -24,18 +26,17 @@ class EnsembleLikelihood(Likelihood):
                                           self.data_covariance_operator,
                                           self.profile)
 
-    def _process_simple_field(self, field, measured_data,
+    def _process_simple_field(self, observable, measured_data,
                               data_covariance_operator, profile):
         # https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula#Generalization
         # B = A^{-1} + U U^dagger
         # A = data_covariance
         # B^{-1} c = (A_inv -
         #             A_inv U (I_k + U^dagger A_inv U)^{-1} U^dagger A_inv) c
-        observable = field
 
         k = observable.shape[0]
+        n = observable.shape[1]
 
-        A = data_covariance_operator
         obs_val = observable.val.get_full_data()
         obs_mean = observable.ensemble_mean().get_full_data()
 
@@ -45,8 +46,39 @@ class EnsembleLikelihood(Likelihood):
         measured_data = measured_data / profile
 
         u_val = obs_val - obs_mean
+
+        # compute quantities for OAS estimator
+        mu = np.vdot(u_val, u_val)/n
+        alpha = 0.
+        for i in xrange(n):
+            alpha += np.sum(u_val.T.dot(u_val[:, i])**2)
+
+        numerator = alpha + mu**2
+        denominator = (k + 1) / (alpha - (mu**2)/n)
+
+        if denominator == 0:
+            rho = 1
+        else:
+            rho = np.min(1, numerator/denominator)
+
+        # rescale U half/half
+        u_val *= np.sqrt(1-rho)
         U = observable.copy_empty()
         U.val = u_val
+
+        # we assume that data_covariance_operator is a DiagonalOperator
+        if not isinstance(data_covariance_operator, DiagonalOperator):
+            raise TypeError("data_covariance_operator must be a NIFTY "
+                            "DiagonalOperator.")
+
+        A_bare_diagonal = data_covariance_operator.diagonal(bare=True)
+        A_bare_diagonal.val += rho*mu
+        A = DiagonalOperator(
+                    domain=data_covariance_operator.domain,
+                    diagonal=A_bare_diagonal,
+                    bare=True, copy=False,
+                    default_spaces=data_covariance_operator.default_spaces)
+
         a_u = A.inverse_times(U, spaces=1)
 
         # build middle-matrix (kxk)
-- 
GitLab