diff --git a/src/python/postprocess/__init__.py b/src/python/postprocess/__init__.py index 1bb9f7c3b64d140e97804d23a1d9cf6f05944b3d..fd104116b783bdb9c97969e761d33da40a8db88b 100644 --- a/src/python/postprocess/__init__.py +++ b/src/python/postprocess/__init__.py @@ -3,7 +3,7 @@ from cpp_sisso.postprocess.plotting import plt, config from cpp_sisso.postprocess.plotting.utils import adjust_box_widths, latexify from cpp_sisso.postprocess.utils import get_models from cpp_sisso import ModelRegressor, ModelClassifier - +import toml import numpy as np import pandas as pd @@ -24,6 +24,9 @@ def generate_plot(dir_expr, filename, fig_settings=None): """ # Set up the figure + if isinstance(fig_settings, str): + fig_settings = toml.load(fig_settings) + fig_config = deepcopy(config) if fig_settings: fig_config.update(fig_settings) @@ -53,7 +56,7 @@ def generate_plot(dir_expr, filename, fig_settings=None): if fig_config["plot_options"]["type"] == "split": # Set yaxis label - ax.set_ylabel("Absolute Error (" + str(models[0][0].prop_unit) + ")") + ax.set_ylabel("Absolute Error (" + latexify(str(models[0][0].prop_unit)) + ")") # Populate data for nn in range(len(models)): @@ -137,7 +140,7 @@ def generate_plot(dir_expr, filename, fig_settings=None): plot_settings["box_edge_colors"] = [fig_config["colors"]["box_train_edge"]] elif fig_config["plot_options"]["type"] == "test": # Set yaxis label - ax.set_ylabel("Absolute Test Error (" + str(models[0][0].prop_unit) + ")") + ax.set_ylabel("Absolute Test Error (" + latexify(str(models[0][0].prop_unit)) + ")") # Populate data for nn in range(len(models)): diff --git a/src/python/postprocess/plotting/utils.py b/src/python/postprocess/plotting/utils.py index cfe88845760ef5f46b6adbe449f9f3e3652432ee..0b515816c3f378c91704f1c83157a7e73c6ea7e7 100644 --- a/src/python/postprocess/plotting/utils.py +++ b/src/python/postprocess/plotting/utils.py @@ -33,14 +33,19 @@ def adjust_box_widths(ax, fac): def latexify(s): """Convert a string s into a latex string""" power_split = s.split("^") + + print(power_split) + if len(power_split) == 1: return s power_split[0] += "$" - for ps in power_split[1:]: - unit_end = ps.split(" ") + for pp in range(1, len(power_split)): + unit_end = power_split[pp].split(" ") unit_end[0] = "{" + unit_end[0] + "}$" unit_end[-1] += "$" - ps = " ".join(unit_end) + power_split[pp] = " ".join(unit_end) + print("^".join(power_split)[:-1]) - return "^".join(power_split)[:-1] \ No newline at end of file + + return "^".join(power_split)[:-1]