Commit b2255f96 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented weight saving callback in model_training.py

parent fefaae76
......@@ -27,7 +27,7 @@ fig = px.scatter(
y="y",
animation_frame="epoch",
color="cluster",
labels={"x": "PCA 1", "y": "PCA 2"},
labels={"x": "LDA 1", "y": "LDA 2"},
width=550,
height=500,
color_discrete_sequence=px.colors.qualitative.T10,
......@@ -37,9 +37,7 @@ fig.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 1
fig.update_xaxes(showgrid=True, range=[-0.2, 1.2])
fig.update_yaxes(showgrid=True, range=[-0.2, 1.2])
fig.update_layout(
{"paper_bgcolor": "rgba(0, 0, 0, 0)"}
)
fig.update_layout({"paper_bgcolor": "rgba(0, 0, 0, 0)"})
# Cluster membership animated over epochs
clust_occur = pd.DataFrame(
......@@ -63,9 +61,7 @@ fig2.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 1
fig2.update_xaxes(showgrid=True)
fig2.update_yaxes(showgrid=True, range=[0, samples / 1.5])
fig2.update_layout(
{"paper_bgcolor": "rgba(0, 0, 0, 0)"}
)
fig2.update_layout({"paper_bgcolor": "rgba(0, 0, 0, 0)"})
# Scatterplot reconstruction over epochs
......@@ -84,9 +80,7 @@ fig3.layout.updatemenus[0].buttons[0].args[1]["frame"]["duration"] = 1
fig3.update_xaxes(showgrid=True, range=[-2, 2])
fig3.update_yaxes(showgrid=True, range=[-2, 2])
fig3.update_layout(
{"paper_bgcolor": "rgba(0, 0, 0, 0)"}
)
fig3.update_layout({"paper_bgcolor": "rgba(0, 0, 0, 0)"})
fig4 = px.line(
data_frame=maedf,
......@@ -94,12 +88,10 @@ fig4 = px.line(
y="mae",
color_discrete_sequence=px.colors.qualitative.T10,
)
#fig4.update_xaxes(showgrid=False)
#fig4.update_yaxes(showgrid=False)
# fig4.update_xaxes(showgrid=False)
# fig4.update_yaxes(showgrid=False)
fig4.update_layout(
{"paper_bgcolor": "rgba(0, 0, 0, 0)"}
)
fig4.update_layout({"paper_bgcolor": "rgba(0, 0, 0, 0)"})
# Combine all three figures in a Dash application
app = dash.Dash(__name__)
......
......@@ -7,6 +7,7 @@ sys.path.insert(1, "../")
from copy import deepcopy
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import mean_absolute_error
from source.preprocess import *
from source.models import *
......@@ -160,13 +161,13 @@ else:
"learning_rate": 1e-3,
}
with open(
os.path.abspath(
data_path + "/" + [i for i in os.listdir(data_path) if i.endswith(".pickle")][0]
),
"rb",
) as handle:
Treatment_dict = pickle.load(handle)
# with open(
# os.path.abspath(
# data_path + "/" + [i for i in os.listdir(data_path) if i.endswith(".pickle")][0]
# ),
# "rb",
# ) as handle:
# Treatment_dict = pickle.load(handle)
# Which angles to compute?
bp_dict = {
......@@ -204,7 +205,7 @@ DLC_social = project(
arena_dims=[380], # Dimensions of the arena. Just one if it's circular
video_format=".mp4",
table_format=".h5",
exp_conditions=Treatment_dict,
# exp_conditions=Treatment_dict,
)
......@@ -304,10 +305,24 @@ for checkpoint in tqdm(checkpoints):
print("Done!")
print("Reducing latent space to 2 dimensions for dataviz...")
reducer = PCA(n_components=2)
encs = [reducer.fit_transform(i) for i in tqdm(predictions)]
reducer = LinearDiscriminantAnalysis(n_components=2)
encs = []
for i in range(len(checkpoints) + 1):
if i == 0:
clusts = (
np.array([int(i) for i in np.random.uniform(0, k, samples)])
if variational
else np.zeros(samples)
)
encs.append(reducer.fit_transform(predictions[i], clusts))
else:
encs.append(
reducer.fit_transform(predictions[i], np.argmax(clusters[i - 1], axis=1))
)
# As projection direction is difficult to predict in PCA,
# As projection direction is difficult to predict in LDA,
# axes are flipped to maintain subsequent representations
# of the input closer to one another
flip_encs = flip_axes(encs)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment