Commit 2c4a0abf authored by lucas_miranda's avatar lucas_miranda
Browse files

Integrated autoencoder fitting to coordinates class in data.py

parent 0ff679af
......@@ -24,6 +24,7 @@ from sklearn.impute import SimpleImputer
from sklearn.manifold import TSNE
from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
from tqdm import tqdm
import deepof.models
import deepof.pose_utils
import deepof.utils
import deepof.visuals
......@@ -777,8 +778,67 @@ class coordinates:
tag_dict, typ="rule-based", arena=self._arena, arena_dims=self._arena_dims
)
def gmvae_embedding(self):
pass
@staticmethod
def deep_unsupervised_embedding(
preprocessed_object: np.array,
encoding_size: int,
batch_size: int = 256,
cp: bool = True,
hparams: dict = None,
kl_warmup: int = 0,
loss: str = "ELBO",
mmd_warmup: int = 0,
montecarlo_kl: int = 10,
n_components: int = 25,
outpath: str = ".",
phenotype_class: float = 0,
predictor: float = 0,
pretrained: str = False,
variational: bool = True,
):
# Load all
X_train, y_train, X_val, y_val = preprocessed_object
if not variational:
encoder, decoder, ae = deepof.models.SEQ_2_SEQ_AE(
({} if hparams is None else hparams)
).build(X_train.shape)
return_list = (encoder, decoder, ae)
else:
(
encoder,
generator,
grouper,
ae,
kl_warmup_callback,
mmd_warmup_callback,
) = deepof.models.SEQ_2_SEQ_GMVAE(
architecture_hparams=({} if hparams is None else hparams),
batch_size=batch_size,
compile_model=True,
encoding=encoding_size,
kl_warmup_epochs=kl_warmup,
loss=loss,
mmd_warmup_epochs=mmd_warmup,
montecarlo_kl=montecarlo_kl,
neuron_control=False,
number_of_components=n_components,
overlap_loss=False,
phenotype_prediction=phenotype_class,
predictor=predictor,
).build(
X_train.shape
)
return_list = (encoder, generator, grouper, ae)
if pretrained:
ae.load_weights(pretrained)
return return_list
# returns a trained tensorflow model
return ae
class table_dict(dict):
......@@ -1033,7 +1093,7 @@ class table_dict(dict):
if self._propagate_labels:
y_train = y_train[shuffle_train]
return X_train, y_train, X_test, y_test
return X_train, y_train, np.array(X_test), np.array(y_test)
def random_projection(
self, n_components: int = None, sample: int = 1000
......
......@@ -521,7 +521,6 @@ class SEQ_2_SEQ_GMVAE:
def build(self, input_shape: Tuple):
"""Builds the tf.keras model"""
print(input_shape)
# Instanciate prior
self.get_prior()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment