Commit 454cbfa9 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented weight saving callback in model_training.py

parent 16bef013
......@@ -281,7 +281,10 @@ else:
gmvaep.build(pttest.shape)
predictions.append(encoder.predict(pttest))
reconstructions.append(gmvaep.predict(pttest))
if predictor:
reconstructions.append(gmvaep.predict(pttest)[0])
else:
reconstructions.append(gmvaep.predict(pttest))
print("Building predictions from pretrained models...")
......@@ -292,7 +295,10 @@ for checkpoint in tqdm(checkpoints):
gmvaep.load_weights(checkpoint)
clusters.append(grouper.predict(pttest))
predictions.append(encoder.predict(pttest))
reconstructions.append(gmvaep.predict(pttest))
if predictor:
reconstructions.append(gmvaep.predict(pttest)[0])
else:
reconstructions.append(gmvaep.predict(pttest))
else:
ae.load_weights(checkpoint)
......@@ -349,9 +355,6 @@ dfencs["epoch"] = np.array(
)
dfencs.columns = ["x", "y", "cluster", "epoch"]
dfencs["trajectories"] = np.tile(pttest[:, 6, 1], len(checkpoints) + 1)
print(np.concatenate(reconstructions).shape)
dfencs["reconstructions"] = np.concatenate(reconstructions)[:, 6, 1]
# Cluster membership animated over epochs
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment