From a2ee6e6041ee80fd883df30917c88cf455480c8f Mon Sep 17 00:00:00 2001
From: Thomas Purcell <purcell@fhi-berlin.mpg.de>
Date: Fri, 28 Aug 2020 17:44:31 +0200
Subject: [PATCH] Move jackknife_cv_conv_est to seperate file

In case more functions are needed
---
 .../postprocess/check_cv_convergence.py       | 44 +++++++++++++++++++
 1 file changed, 44 insertions(+)
 create mode 100644 src/python/postprocess/check_cv_convergence.py

diff --git a/src/python/postprocess/check_cv_convergence.py b/src/python/postprocess/check_cv_convergence.py
new file mode 100644
index 00000000..12073daa
--- /dev/null
+++ b/src/python/postprocess/check_cv_convergence.py
@@ -0,0 +1,44 @@
+import numpy as np
+from glob import glob
+
+
+def jackknife_cv_conv_est(dir_expr):
+    """Get the jackknife variance of the CV test error
+
+    Args:
+        dir_expr (str): Regular expression for the directory list
+
+    Returns:
+        avg_error: The average rmse of the test error
+        variance: The jackknife estimate of the variance of the test RMSE
+    """
+    train_model_list = sorted(
+        glob(dir_expr + "/models/train_*_model_0.dat"),
+        key=lambda s: int(s.split("/")[-1].split("_")[2]),
+    )
+    test_model_list = sorted(
+        glob(dir_expr + "/models/test_*_model_0.dat"),
+        key=lambda s: int(s.split("/")[-1].split("_")[2]),
+    )
+    n_dim = int(train_model_list[-1].split("/")[-1].split("_")[2])
+
+    models = [
+        Model(train_file, test_file)
+        for train_file, test_file in zip(train_model_list, test_model_list)
+    ]
+
+    test_rmse = np.array([model.test_rmse for model in models]).reshape(n_dim, -1)
+    x_bar_i = []
+    for dim_error in test_rmse:
+        x_bar_i.append([])
+        for ii in range(len(dim_error)):
+            x_bar_i[-1].append(np.delete(dim_error, ii).mean())
+
+    x_bar_i = np.array(x_bar_i)
+    avg_error = x_bar_i.mean(axis=1)
+    variance = (
+        (test_rmse.shape[1] - 1.0)
+        / test_rmse.shape[1]
+        * np.sum((x_bar_i - avg_error.reshape(n_dim, 1)) ** 2.0, axis=1)
+    )
+    return avg_error, variance
-- 
GitLab