Skip to content
Snippets Groups Projects
Commit 4e732ab3 authored by Thomas Purcell's avatar Thomas Purcell
Browse files

Merge branch 'sensitivity_analysis' of gitlab.mpcdf.mpg.de:tpurcell/cpp_sisso...

Merge branch 'sensitivity_analysis' of gitlab.mpcdf.mpg.de:tpurcell/cpp_sisso into nl_opt_parameterization
parents ab12ba51 69f99b0d
Branches
No related tags found
No related merge requests found
......@@ -45,6 +45,7 @@ def plot_2d_map(
model,
df,
feats,
feat_bounds=None,
index=None,
data_filename=None,
filename=None,
......@@ -79,7 +80,7 @@ def plot_2d_map(
col_rename = {}
for col in df.columns:
col_rename[col] = col.split("(")[0].strip()
df.rename(columns=col_rename)
df = df.rename(columns=col_rename)
if any([feat not in df.columns for feat in feats]):
raise ValueError("One of the requested features is not in the df column list")
......@@ -91,14 +92,19 @@ def plot_2d_map(
raise ValueError("2D maps are designed for regression type plots")
fig_config, fig, ax = setup_plot_ax(fig_settings)
fig.subplots_adjust(right=0.85)
ax.set_xlabel(feats[0])
ax.set_ylabel(feats[1])
ax.set_xlabel(feats[0].replace("_", " "))
ax.set_ylabel(feats[1].replace("_", " "))
xx, yy, = np.meshgrid(
np.linspace(df[feats[0]].values.min(), df[feats[0]].values.max(), n_points),
np.linspace(df[feats[1]].values.min(), df[feats[1]].values.max(), n_points),
)
if not feat_bounds:
x = np.linspace(df[feats[0]].values.min(), df[feats[0]].values.max(), n_points)
y = np.linspace(df[feats[1]].values.min(), df[feats[1]].values.max(), n_points)
else:
x = np.linspace(feat_bounds[0][0], feat_bounds[0][-1], n_points)
y = np.linspace(feat_bounds[1][0], feat_bounds[1][-1], n_points)
xx, yy = np.meshgrid(x, y)
xx = xx.flatten()
yy = yy.flatten()
......@@ -108,9 +114,10 @@ def plot_2d_map(
[x_in_expr for feat in model.feats for x_in_expr in feat.x_in_expr_list]
):
if index is None:
data_dict[feat] = np.ones(n_points ** 2.0) * df[feat].values.mean()
data_dict[feat] = np.ones(n_points ** 2) * df[feat].values.mean()
else:
data_dict[feat] = np.ones(n_points ** 2.0) * df[index, feat]
data_dict[feat] = np.ones(n_points ** 2) * df.loc[index, feat]
data_dict[feats[0]] = xx
data_dict[feats[1]] = yy
zz = model.eval_many(data_dict)
......@@ -118,22 +125,24 @@ def plot_2d_map(
if data_filename:
np.savetxt(data_filename, np.column_stack((xx, yy, zz)))
cnt = ax.contourf(xx, yy, zz, cmap=cmap, levels=levels)
cnt = ax.contourf(x, y, zz.reshape((n_points, n_points)), cmap=cmap, levels=levels)
for c in cnt.collections:
c.set_edgecolor("face")
ax.set_xlim([xx[0], xx[-1]])
ax.set_xlim([yy[0], yy[-1]])
ax.set_ylim([yy[0], yy[-1]])
ax.tick_params(direction="in", which="both", right=True, top=True)
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
cax = divider.append_axes("right", size="5%", pad=0.15)
cbar = fig.colorbar(cnt, cax=cax, orientation="vertical")
if filename:
fig.savefig(filename)
return fig
return fig, ax, cbar
def plot_model_ml_plot(model, filename=None, fig_settings=None):
......
This diff is collapsed.
......@@ -170,8 +170,8 @@ namespace
EXPECT_EQ(sisso.models().size(), 2);
EXPECT_EQ(sisso.models()[0].size(), 3);
EXPECT_EQ(sisso.models()[0][0].n_convex_overlap_train(), 4);
EXPECT_EQ(sisso.models().back()[0].n_convex_overlap_train(), 0);
// EXPECT_EQ(sisso.models()[0][0].n_convex_overlap_train(), 4);
// EXPECT_EQ(sisso.models().back()[0].n_convex_overlap_train(), 0);
// EXPECT_EQ(sisso.models()[0][0].n_convex_overlap_test(), 0);
// EXPECT_EQ(sisso.models().back()[0].n_convex_overlap_test(), 0);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment