Commit f4b22d1c authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented outlier interpolation

parent 9f79d50e
......@@ -28,6 +28,7 @@ import deepof.models
import deepof.pose_utils
import deepof.utils
import deepof.visuals
import deepof.train_utils
import matplotlib.pyplot as plt
import numpy as np
import os
......
......@@ -140,13 +140,13 @@ def get_callbacks(
def deep_unsupervised_embedding():
"""Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
# Load all
# Load data
X_train, y_train, X_val, y_val = preprocessed_object
# Load callbacks
# To avoid stability issues
tf.keras.backend.clear_session()
# Load callbacks
run_ID, tensorboard_callback, onecycle, cp_callback = get_callbacks(
X_train=X_train,
batch_size=batch_size,
......@@ -159,6 +159,7 @@ def deep_unsupervised_embedding():
outpath=output_path,
)
# Build models
if not variational:
encoder, decoder, ae = deepof.models.SEQ_2_SEQ_AE(
({} if hparams is None else hparams)
......@@ -193,6 +194,7 @@ def deep_unsupervised_embedding():
return_list = (encoder, generator, grouper, ae)
if pretrained:
# If pretrained models are specified, load weights and return
ae.load_weights(pretrained)
return return_list
......@@ -221,6 +223,25 @@ def deep_unsupervised_embedding():
else:
callbacks_ = [
tensorboard_callback,
# cp_callback,
onecycle,
CustomStopper(
monitor="val_loss",
patience=5,
restore_best_weights=True,
start_epoch=max(kl_wu, mmd_wu),
),
]
if "ELBO" in loss and kl_wu > 0:
# noinspection PyUnboundLocalVariable
callbacks_.append(kl_warmup_callback)
if "MMD" in loss and mmd_wu > 0:
# noinspection PyUnboundLocalVariable
callbacks_.append(mmd_warmup_callback)
Xs, ys = [X_train], [X_train]
Xvals, yvals = [X_val], [X_val]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment