diff --git a/src/python/postprocess/__init__.py b/src/python/postprocess/__init__.py
index d3aa9e5bcdd77c8487d740683f814751edaa5409..1bb9f7c3b64d140e97804d23a1d9cf6f05944b3d 100644
--- a/src/python/postprocess/__init__.py
+++ b/src/python/postprocess/__init__.py
@@ -1,7 +1,9 @@
 from cpp_sisso.postprocess.check_cv_convergence import jackknife_cv_conv_est
 from cpp_sisso.postprocess.plotting import plt, config
-from cpp_sisso.postprocess.plotting.utils import adjust_box_widths
+from cpp_sisso.postprocess.plotting.utils import adjust_box_widths, latexify
 from cpp_sisso.postprocess.utils import get_models
+from cpp_sisso import ModelRegressor, ModelClassifier
+
 import numpy as np
 import pandas as pd
 
@@ -107,7 +109,7 @@ def generate_plot(dir_expr, filename, fig_settings=None):
         ]
     elif fig_config["plot_options"]["type"] == "train":
         # Set yaxis label
-        ax.set_ylabel("Absolute Training Error (" + str(models[0][0].prop_unit) + ")")
+        ax.set_ylabel("Absolute Training Error (" + latexify(str(models[0][0].prop_unit)) + ")")
         # Populate data
         for nn in range(len(models)):
             train_error = np.array(
@@ -336,7 +338,7 @@ def jackknife_cv_conv_est(dir_expr):
     n_dim = int(train_model_list[-1].split("/")[-1].split("_")[2])
 
     models = [
-        Model(train_file, test_file)
+        ModelRegressor(train_file, test_file)
         for train_file, test_file in zip(train_model_list, test_model_list)
     ]
 
diff --git a/src/python/postprocess/plotting/__init__.py b/src/python/postprocess/plotting/__init__.py
index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..9d6df417f5c6c8e6e442c93d889b0da5123d120d 100644
--- a/src/python/postprocess/plotting/__init__.py
+++ b/src/python/postprocess/plotting/__init__.py
@@ -0,0 +1 @@
+from cpp_sisso.postprocess.plotting.config import *
diff --git a/src/python/postprocess/plotting/default.mpstyle b/src/python/postprocess/plotting/default.mplstyle
similarity index 100%
rename from src/python/postprocess/plotting/default.mpstyle
rename to src/python/postprocess/plotting/default.mplstyle
diff --git a/src/python/postprocess/plotting/utils.py b/src/python/postprocess/plotting/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfe88845760ef5f46b6adbe449f9f3e3652432ee
--- /dev/null
+++ b/src/python/postprocess/plotting/utils.py
@@ -0,0 +1,45 @@
+from matplotlib.patches import PathPatch
+import numpy as np
+
+def adjust_box_widths(ax, fac):
+    """
+    Adjust the widths of a seaborn-generated boxplot.
+ """ + # iterating through axes artists: + for c in ax.get_children(): + + # searching for PathPatches + if isinstance(c, PathPatch): + # getting current width of box: + p = c.get_path() + verts = p.vertices + verts_sub = verts[:-1] + xmin = np.min(verts_sub[:, 0]) + xmax = np.max(verts_sub[:, 0]) + xmid = 0.5 * (xmin + xmax) + xhalf = 0.5 * (xmax - xmin) + + # setting new width of box + xmin_new = xmid - fac * xhalf + xmax_new = xmid + fac * xhalf + verts_sub[verts_sub[:, 0] == xmin, 0] = xmin_new + verts_sub[verts_sub[:, 0] == xmax, 0] = xmax_new + + # setting new width of median line + for l in ax.lines: + if np.all(l.get_xdata() == [xmin, xmax]): + l.set_xdata([xmin_new, xmax_new]) + +def latexify(s): + """Convert a string s into a latex string""" + power_split = s.split("^") + if len(power_split) == 1: + return s + + power_split[0] += "$" + for ps in power_split[1:]: + unit_end = ps.split(" ") + unit_end[0] = "{" + unit_end[0] + "}$" + unit_end[-1] += "$" + ps = " ".join(unit_end) + print("^".join(power_split)[:-1]) + return "^".join(power_split)[:-1] \ No newline at end of file diff --git a/src/python/postprocess/utils.py b/src/python/postprocess/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d3d42488e30a1d79250ce111b06e3c82fac35ba5 --- /dev/null +++ b/src/python/postprocess/utils.py @@ -0,0 +1,41 @@ +from cpp_sisso import ModelRegressor, ModelClassifier +from glob import glob +import numpy as np + +def sort_model_file_key(s): + """Function to determine the order of model files to import + + Args: + s (str): filename + + Return: + int: key to sort model files + """ + return s.split("/")[-1].split("_")[2] + + +def get_models(dir_expr): + """From a regular expression generate a list of models + + Args: + dir_expr (str): Base expression to find the models + + Return: + list of Models: Models represented by the expression + """ + train_model_list = sorted( + glob(dir_expr + "/models/train_*_model_0.dat"), key=sort_model_file_key + ) + test_model_list = sorted( + glob(dir_expr + "/models/test_*_model_0.dat"), key=sort_model_file_key + ) + n_dim = int(train_model_list[-1].split("/")[-1].split("_")[2]) + if len(test_model_list) > 0: + models = [ + ModelRegressor(train_file, test_file) + for train_file, test_file in zip(train_model_list, test_model_list) + ] + else: + models = [ModelRegressor(train_file) for train_file in train_model_list] + + return np.array(models).reshape((n_dim, len(models) // n_dim))