Commit 9f79d50e authored by lucas_miranda

Implemented outlier interpolation

parent 7a6bbd3f
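The commit title refers to smoothing tracking glitches in the coordinate tables. A minimal sketch of the idea, for orientation only (the function name, the jump-based criterion, and the pandas approach are assumptions, not the committed code):

import pandas as pd

def interpolate_outliers(coords: pd.DataFrame, max_jump: float = 10.0) -> pd.DataFrame:
    """Replaces implausible frame-to-frame jumps with linearly interpolated values."""
    # Absolute displacement of every tracked feature between consecutive frames
    jumps = coords.diff().abs()
    # Flag frames in which any feature moves further than max_jump
    outliers = (jumps > max_jump).any(axis=1)
    # Blank the flagged frames and fill them by linear interpolation
    return coords.mask(outliers).interpolate(method="linear", limit_direction="both")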
deepof/data.py
@@ -781,9 +781,8 @@ class coordinates:
@staticmethod
def deep_unsupervised_embedding(
preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
encoding_size: int = 4,
batch_size: int = 256,
cp: bool = True,
encoding_size: int = 4,
hparams: dict = None,
kl_warmup: int = 0,
loss: str = "ELBO",
@@ -794,17 +793,19 @@ class coordinates:
phenotype_class: float = 0,
predictor: float = 0,
pretrained: str = False,
save_checkpoints: bool = True,
variational: bool = True,
) -> Tuple:
"""
Annotates coordinates using an unsupervised autoencoder
Annotates coordinates using an unsupervised autoencoder.
Full implementation in deepof.train_utils.deep_unsupervised_embedding
Parameters:
- preprocessed_object (Tuple[np.ndarray]): tuple containing a preprocessed object (X_train,
y_train, X_test, y_test)
- encoding_size (int): number of dimensions in the latent space of the autoencoder
- batch_size (int): training batch size
- cp (bool): if True, training checkpoints are saved to disk. Useful for debugging,
- save_checkpoints (bool): if True, training checkpoints are saved to disk. Useful for debugging,
but can make training significantly slower
- hparams (dict): dictionary to change architecture hyperparameters of the autoencoders
(see documentation for details)
@@ -831,51 +832,26 @@
"""
# Load all
X_train, y_train, X_val, y_val = preprocessed_object
if not variational:
encoder, decoder, ae = deepof.models.SEQ_2_SEQ_AE(
({} if hparams is None else hparams)
).build(X_train.shape)
return_list = (encoder, decoder, ae)
else:
(
encoder,
generator,
grouper,
ae,
kl_warmup_callback,
mmd_warmup_callback,
) = deepof.models.SEQ_2_SEQ_GMVAE(
architecture_hparams=({} if hparams is None else hparams),
batch_size=batch_size,
compile_model=True,
encoding=encoding_size,
kl_warmup_epochs=kl_warmup,
loss=loss,
mmd_warmup_epochs=mmd_warmup,
montecarlo_kl=montecarlo_kl,
neuron_control=False,
number_of_components=n_components,
overlap_loss=False,
phenotype_prediction=phenotype_class,
predictor=predictor,
).build(
X_train.shape
)
return_list = (encoder, generator, grouper, ae)
if pretrained:
ae.load_weights(pretrained)
return return_list
else:
pass
trained_models = deepof.train_utils.deep_unsupervised_embedding(
preprocessed_object,
batch_size,
encoding_size,
hparams,
kl_warmup,
loss,
mmd_warmup,
montecarlo_kl,
n_components,
outpath,
phenotype_class,
predictor,
pretrained,
save_checkpoints,
variational,
)
# returns a list of trained tensorflow models
return ae
return trained_models
class table_dict(dict):
......
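After this refactor the coordinates method is a thin wrapper around deepof.train_utils; a usage sketch, assuming `coords` is a deepof.data.coordinates instance and `preprocessed_object` is the (X_train, y_train, X_test, y_test) tuple documented above:

encoder, generator, grouper, gmvae = coords.deep_unsupervised_embedding(
    preprocessed_object,
    batch_size=256,
    encoding_size=4,
    loss="ELBO",
    variational=True,
)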
deepof/train_utils.py
@@ -137,6 +137,117 @@ def get_callbacks(
return callbacks
def deep_unsupervised_embedding(
    preprocessed_object: Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    batch_size: int,
    encoding_size: int,
    hparams: dict,
    kl_warmup: int,
    loss: str,
    mmd_warmup: int,
    montecarlo_kl: int,
    n_components: int,
    output_path: str,
    phenotype_class: float,
    predictor: float,
    pretrained: str,
    save_checkpoints: bool,
    variational: bool,
    logparam: dict = None,  # optional logging dict forwarded to get_callbacks
) -> Tuple:
"""Implementation function for deepof.data.coordinates.deep_unsupervised_embedding"""
    # Unpack the preprocessed data tuple
    X_train, y_train, X_val, y_val = preprocessed_object

    # Clear any previous graph state to avoid stability issues between runs
    tf.keras.backend.clear_session()
run_ID, tensorboard_callback, onecycle, cp_callback = get_callbacks(
X_train=X_train,
batch_size=batch_size,
cp=save_checkpoints,
variational=variational,
        phenotype_class=phenotype_class,
predictor=predictor,
loss=loss,
logparam=logparam,
outpath=output_path,
)
if not variational:
encoder, decoder, ae = deepof.models.SEQ_2_SEQ_AE(
({} if hparams is None else hparams)
).build(X_train.shape)
return_list = (encoder, decoder, ae)
else:
(
encoder,
generator,
grouper,
ae,
kl_warmup_callback,
mmd_warmup_callback,
) = deepof.models.SEQ_2_SEQ_GMVAE(
architecture_hparams=({} if hparams is None else hparams),
batch_size=batch_size,
compile_model=True,
encoding=encoding_size,
kl_warmup_epochs=kl_warmup,
loss=loss,
mmd_warmup_epochs=mmd_warmup,
montecarlo_kl=montecarlo_kl,
neuron_control=False,
number_of_components=n_components,
overlap_loss=False,
phenotype_prediction=phenotype_class,
predictor=predictor,
).build(
X_train.shape
)
return_list = (encoder, generator, grouper, ae)
if pretrained:
ae.load_weights(pretrained)
return return_list
else:
if not variational:
ae.fit(
x=X_train,
y=X_train,
epochs=35,
batch_size=batch_size,
verbose=1,
validation_data=(X_val, X_val),
callbacks=[
tensorboard_callback,
cp_callback,
onecycle,
CustomStopper(
monitor="val_loss",
patience=5,
restore_best_weights=True,
                    start_epoch=max(kl_warmup, mmd_warmup),
),
],
)
else:
Xs, ys = [X_train], [X_train]
Xvals, yvals = [X_val], [X_val]
if predictor > 0.0:
Xs, ys = X_train[:-1], [X_train[:-1], X_train[1:]]
Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
            if phenotype_class > 0.0:
ys += [y_train]
yvals += [y_val]
            # Assemble the callback list; the warmup callbacks come from the model build above
            callbacks_ = [
                tensorboard_callback,
                cp_callback,
                onecycle,
                CustomStopper(
                    monitor="val_loss",
                    patience=5,
                    restore_best_weights=True,
                    start_epoch=max(kl_warmup, mmd_warmup),
                ),
            ]
            if "ELBO" in loss and kl_warmup > 0:
                callbacks_.append(kl_warmup_callback)
            if "MMD" in loss and mmd_warmup > 0:
                callbacks_.append(mmd_warmup_callback)

            ae.fit(
x=Xs,
y=ys,
epochs=35,
batch_size=batch_size,
verbose=1,
validation_data=(
Xvals,
yvals,
),
callbacks=callbacks_,
)
return return_list
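CustomStopper is referenced above but lies outside this diff; a minimal sketch of a compatible implementation (an assumption, not necessarily the repository's version) is an EarlyStopping subclass that stays inert until the KL/MMD warmup epochs have passed:

import tensorflow as tf

class CustomStopper(tf.keras.callbacks.EarlyStopping):
    """Early stopping that only becomes active after start_epoch."""

    def __init__(self, start_epoch: int = 0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_epoch = start_epoch

    def on_epoch_end(self, epoch, logs=None):
        # Ignore epochs that fall inside the annealing warmup window
        if epoch > self.start_epoch:
            super().on_epoch_end(epoch, logs)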
def tune_search(
data: List[np.array],
encoding_size: int,
......