From 6a3dcfceb73b8ff816f94f3ffc509f816264f7da Mon Sep 17 00:00:00 2001
From: Thomas Purcell <purcell@fhi-berlin.mpg.de>
Date: Fri, 25 Sep 2020 10:41:30 +0200
Subject: [PATCH] update to postprocessing

---
 src/python/postprocess/__init__.py            |  8 ++--
 src/python/postprocess/plotting/__init__.py   |  1 +
 .../{default.mpstyle => default.mplstyle}     |  0
 src/python/postprocess/plotting/utils.py      | 46 +++++++++++++++++++
 src/python/postprocess/utils.py               | 41 +++++++++++++++++
 5 files changed, 93 insertions(+), 3 deletions(-)
 rename src/python/postprocess/plotting/{default.mpstyle => default.mplstyle} (100%)
 create mode 100644 src/python/postprocess/plotting/utils.py
 create mode 100644 src/python/postprocess/utils.py

diff --git a/src/python/postprocess/__init__.py b/src/python/postprocess/__init__.py
index d3aa9e5b..1bb9f7c3 100644
--- a/src/python/postprocess/__init__.py
+++ b/src/python/postprocess/__init__.py
@@ -1,7 +1,9 @@
 from cpp_sisso.postprocess.check_cv_convergence import jackknife_cv_conv_est
 from cpp_sisso.postprocess.plotting import plt, config
-from cpp_sisso.postprocess.plotting.utils import adjust_box_widths
+from cpp_sisso.postprocess.plotting.utils import adjust_box_widths, latexify
 from cpp_sisso.postprocess.utils import get_models
+from cpp_sisso import ModelRegressor, ModelClassifier
+
 
 import numpy as np
 import pandas as pd
@@ -107,7 +109,7 @@ def generate_plot(dir_expr, filename, fig_settings=None):
         ]
     elif fig_config["plot_options"]["type"] == "train":
         # Set yaxis label
-        ax.set_ylabel("Absolute Training Error (" + str(models[0][0].prop_unit) + ")")
+        ax.set_ylabel("Absolute Training Error (" + latexify(str(models[0][0].prop_unit)) + ")")
         # Populate data
         for nn in range(len(models)):
             train_error = np.array(
@@ -336,7 +338,7 @@ def jackknife_cv_conv_est(dir_expr):
     n_dim = int(train_model_list[-1].split("/")[-1].split("_")[2])
 
     models = [
-        Model(train_file, test_file)
+        ModelRegressor(train_file, test_file)
         for train_file, test_file in zip(train_model_list, test_model_list)
     ]
 
diff --git a/src/python/postprocess/plotting/__init__.py b/src/python/postprocess/plotting/__init__.py
index e69de29b..9d6df417 100644
--- a/src/python/postprocess/plotting/__init__.py
+++ b/src/python/postprocess/plotting/__init__.py
@@ -0,0 +1 @@
+from cpp_sisso.postprocess.plotting.config import *
diff --git a/src/python/postprocess/plotting/default.mpstyle b/src/python/postprocess/plotting/default.mplstyle
similarity index 100%
rename from src/python/postprocess/plotting/default.mpstyle
rename to src/python/postprocess/plotting/default.mplstyle
diff --git a/src/python/postprocess/plotting/utils.py b/src/python/postprocess/plotting/utils.py
new file mode 100644
index 00000000..cfe88845
--- /dev/null
+++ b/src/python/postprocess/plotting/utils.py
@@ -0,0 +1,46 @@
+from matplotlib.patches import PathPatch
+import numpy as np
+
+def adjust_box_widths(ax, fac):
+    """
+    Adjust the widths of a seaborn-generated boxplot.
+    """
+    # iterating through axes artists:
+    for c in ax.get_children():
+
+        # searching for PathPatches
+        if isinstance(c, PathPatch):
+            # getting current width of box:
+            p = c.get_path()
+            verts = p.vertices
+            verts_sub = verts[:-1]
+            xmin = np.min(verts_sub[:, 0])
+            xmax = np.max(verts_sub[:, 0])
+            xmid = 0.5 * (xmin + xmax)
+            xhalf = 0.5 * (xmax - xmin)
+
+            # setting new width of box
+            xmin_new = xmid - fac * xhalf
+            xmax_new = xmid + fac * xhalf
+            verts_sub[verts_sub[:, 0] == xmin, 0] = xmin_new
+            verts_sub[verts_sub[:, 0] == xmax, 0] = xmax_new
+
+            # setting new width of median line
+            for l in ax.lines:
+                if np.all(l.get_xdata() == [xmin, xmax]):
+                    l.set_xdata([xmin_new, xmax_new])
+
+def latexify(s):
+    """Convert a string s into a latex string"""
+    power_split = s.split("^")
+    if len(power_split) == 1:
+        return s
+
+    power_split[0] += "$"
+    for ps in power_split[1:]:
+        unit_end = ps.split(" ")
+        unit_end[0] = "{" + unit_end[0] + "}$"
+        unit_end[-1] += "$"
+        ps = " ".join(unit_end)
+    print("^".join(power_split)[:-1])
+    return "^".join(power_split)[:-1]
\ No newline at end of file
diff --git a/src/python/postprocess/utils.py b/src/python/postprocess/utils.py
new file mode 100644
index 00000000..d3d42488
--- /dev/null
+++ b/src/python/postprocess/utils.py
@@ -0,0 +1,41 @@
+from cpp_sisso import ModelRegressor, ModelClassifier
+from glob import glob
+import numpy as np
+
+def sort_model_file_key(s):
+    """Function to determine the order of model files to import
+
+    Args:
+        s (str): filename
+
+    Return:
+        int: key to sort model files
+    """
+    return s.split("/")[-1].split("_")[2]
+
+
+def get_models(dir_expr):
+    """From a regular expression generate a list of models
+
+    Args:
+        dir_expr (str): Base expression to find the models
+
+    Return:
+        list of Models: Models represented by the expression
+    """
+    train_model_list = sorted(
+        glob(dir_expr + "/models/train_*_model_0.dat"), key=sort_model_file_key
+    )
+    test_model_list = sorted(
+        glob(dir_expr + "/models/test_*_model_0.dat"), key=sort_model_file_key
+    )
+    n_dim = int(train_model_list[-1].split("/")[-1].split("_")[2])
+    if len(test_model_list) > 0:
+        models =  [
+            ModelRegressor(train_file, test_file)
+            for train_file, test_file in zip(train_model_list, test_model_list)
+        ]
+    else:
+        models =  [ModelRegressor(train_file) for train_file in train_model_list]
+
+    return np.array(models).reshape((n_dim, len(models) // n_dim))
-- 
GitLab