Skip to content
Snippets Groups Projects
Commit 4e732ab3 authored by Thomas Purcell's avatar Thomas Purcell
Browse files

Merge branch 'sensitivity_analysis' of gitlab.mpcdf.mpg.de:tpurcell/cpp_sisso...

Merge branch 'sensitivity_analysis' of gitlab.mpcdf.mpg.de:tpurcell/cpp_sisso into nl_opt_parameterization
parents ab12ba51 69f99b0d
Branches
No related tags found
No related merge requests found
...@@ -45,6 +45,7 @@ def plot_2d_map( ...@@ -45,6 +45,7 @@ def plot_2d_map(
model, model,
df, df,
feats, feats,
feat_bounds=None,
index=None, index=None,
data_filename=None, data_filename=None,
filename=None, filename=None,
...@@ -79,7 +80,7 @@ def plot_2d_map( ...@@ -79,7 +80,7 @@ def plot_2d_map(
col_rename = {} col_rename = {}
for col in df.columns: for col in df.columns:
col_rename[col] = col.split("(")[0].strip() col_rename[col] = col.split("(")[0].strip()
df.rename(columns=col_rename) df = df.rename(columns=col_rename)
if any([feat not in df.columns for feat in feats]): if any([feat not in df.columns for feat in feats]):
raise ValueError("One of the requested features is not in the df column list") raise ValueError("One of the requested features is not in the df column list")
...@@ -91,14 +92,19 @@ def plot_2d_map( ...@@ -91,14 +92,19 @@ def plot_2d_map(
raise ValueError("2D maps are designed for regression type plots") raise ValueError("2D maps are designed for regression type plots")
fig_config, fig, ax = setup_plot_ax(fig_settings) fig_config, fig, ax = setup_plot_ax(fig_settings)
fig.subplots_adjust(right=0.85)
ax.set_xlabel(feats[0]) ax.set_xlabel(feats[0].replace("_", " "))
ax.set_ylabel(feats[1]) ax.set_ylabel(feats[1].replace("_", " "))
xx, yy, = np.meshgrid( if not feat_bounds:
np.linspace(df[feats[0]].values.min(), df[feats[0]].values.max(), n_points), x = np.linspace(df[feats[0]].values.min(), df[feats[0]].values.max(), n_points)
np.linspace(df[feats[1]].values.min(), df[feats[1]].values.max(), n_points), y = np.linspace(df[feats[1]].values.min(), df[feats[1]].values.max(), n_points)
) else:
x = np.linspace(feat_bounds[0][0], feat_bounds[0][-1], n_points)
y = np.linspace(feat_bounds[1][0], feat_bounds[1][-1], n_points)
xx, yy = np.meshgrid(x, y)
xx = xx.flatten() xx = xx.flatten()
yy = yy.flatten() yy = yy.flatten()
...@@ -108,9 +114,10 @@ def plot_2d_map( ...@@ -108,9 +114,10 @@ def plot_2d_map(
[x_in_expr for feat in model.feats for x_in_expr in feat.x_in_expr_list] [x_in_expr for feat in model.feats for x_in_expr in feat.x_in_expr_list]
): ):
if index is None: if index is None:
data_dict[feat] = np.ones(n_points ** 2.0) * df[feat].values.mean() data_dict[feat] = np.ones(n_points ** 2) * df[feat].values.mean()
else: else:
data_dict[feat] = np.ones(n_points ** 2.0) * df[index, feat] data_dict[feat] = np.ones(n_points ** 2) * df.loc[index, feat]
data_dict[feats[0]] = xx data_dict[feats[0]] = xx
data_dict[feats[1]] = yy data_dict[feats[1]] = yy
zz = model.eval_many(data_dict) zz = model.eval_many(data_dict)
...@@ -118,22 +125,24 @@ def plot_2d_map( ...@@ -118,22 +125,24 @@ def plot_2d_map(
if data_filename: if data_filename:
np.savetxt(data_filename, np.column_stack((xx, yy, zz))) np.savetxt(data_filename, np.column_stack((xx, yy, zz)))
cnt = ax.contourf(xx, yy, zz, cmap=cmap, levels=levels) cnt = ax.contourf(x, y, zz.reshape((n_points, n_points)), cmap=cmap, levels=levels)
for c in cnt.collections: for c in cnt.collections:
c.set_edgecolor("face") c.set_edgecolor("face")
ax.set_xlim([xx[0], xx[-1]]) ax.set_xlim([xx[0], xx[-1]])
ax.set_xlim([yy[0], yy[-1]]) ax.set_ylim([yy[0], yy[-1]])
ax.tick_params(direction="in", which="both", right=True, top=True)
divider = make_axes_locatable(ax) divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05) cax = divider.append_axes("right", size="5%", pad=0.15)
cbar = fig.colorbar(cnt, cax=cax, orientation="vertical") cbar = fig.colorbar(cnt, cax=cax, orientation="vertical")
if filename: if filename:
fig.savefig(filename) fig.savefig(filename)
return fig return fig, ax, cbar
def plot_model_ml_plot(model, filename=None, fig_settings=None): def plot_model_ml_plot(model, filename=None, fig_settings=None):
......
This diff is collapsed.
...@@ -170,8 +170,8 @@ namespace ...@@ -170,8 +170,8 @@ namespace
EXPECT_EQ(sisso.models().size(), 2); EXPECT_EQ(sisso.models().size(), 2);
EXPECT_EQ(sisso.models()[0].size(), 3); EXPECT_EQ(sisso.models()[0].size(), 3);
EXPECT_EQ(sisso.models()[0][0].n_convex_overlap_train(), 4); // EXPECT_EQ(sisso.models()[0][0].n_convex_overlap_train(), 4);
EXPECT_EQ(sisso.models().back()[0].n_convex_overlap_train(), 0); // EXPECT_EQ(sisso.models().back()[0].n_convex_overlap_train(), 0);
// EXPECT_EQ(sisso.models()[0][0].n_convex_overlap_test(), 0); // EXPECT_EQ(sisso.models()[0][0].n_convex_overlap_test(), 0);
// EXPECT_EQ(sisso.models().back()[0].n_convex_overlap_test(), 0); // EXPECT_EQ(sisso.models().back()[0].n_convex_overlap_test(), 0);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment