Commit 2c4a0abf authored by lucas_miranda

Integrated autoencoder fitting to coordinates class in data.py

parent 0ff679af
@@ -24,6 +24,7 @@ from sklearn.impute import SimpleImputer
 from sklearn.manifold import TSNE
 from sklearn.preprocessing import MinMaxScaler, StandardScaler, LabelEncoder
 from tqdm import tqdm
+import deepof.models
 import deepof.pose_utils
 import deepof.utils
 import deepof.visuals
@@ -777,8 +778,67 @@ class coordinates:
             tag_dict, typ="rule-based", arena=self._arena, arena_dims=self._arena_dims
         )

-    def gmvae_embedding(self):
-        pass
+    @staticmethod
+    def deep_unsupervised_embedding(
+        preprocessed_object: np.array,
+        encoding_size: int,
+        batch_size: int = 256,
+        cp: bool = True,
+        hparams: dict = None,
+        kl_warmup: int = 0,
+        loss: str = "ELBO",
+        mmd_warmup: int = 0,
+        montecarlo_kl: int = 10,
+        n_components: int = 25,
+        outpath: str = ".",
+        phenotype_class: float = 0,
+        predictor: float = 0,
+        pretrained: str = False,
+        variational: bool = True,
+    ):
+
+        # Load all elements of the preprocessed training / validation split
+        X_train, y_train, X_val, y_val = preprocessed_object
+
+        # Build the requested model: a plain seq2seq autoencoder, or the
+        # Gaussian-mixture variational autoencoder (GMVAE)
+        if not variational:
+            encoder, decoder, ae = deepof.models.SEQ_2_SEQ_AE(
+                ({} if hparams is None else hparams)
+            ).build(X_train.shape)
+            return_list = (encoder, decoder, ae)
+
+        else:
+            (
+                encoder,
+                generator,
+                grouper,
+                ae,
+                kl_warmup_callback,
+                mmd_warmup_callback,
+            ) = deepof.models.SEQ_2_SEQ_GMVAE(
+                architecture_hparams=({} if hparams is None else hparams),
+                batch_size=batch_size,
+                compile_model=True,
+                encoding=encoding_size,
+                kl_warmup_epochs=kl_warmup,
+                loss=loss,
+                mmd_warmup_epochs=mmd_warmup,
+                montecarlo_kl=montecarlo_kl,
+                neuron_control=False,
+                number_of_components=n_components,
+                overlap_loss=False,
+                phenotype_prediction=phenotype_class,
+                predictor=predictor,
+            ).build(X_train.shape)
+            return_list = (encoder, generator, grouper, ae)
+
+        if pretrained:
+            # Load weights from a previous run and return the model components
+            ae.load_weights(pretrained)
+            return return_list
+
+        # returns a trained tensorflow model
+        return ae


 class table_dict(dict):
@@ -1033,7 +1093,7 @@ class table_dict(dict):
         if self._propagate_labels:
             y_train = y_train[shuffle_train]

-        return X_train, y_train, X_test, y_test
+        return X_train, y_train, np.array(X_test), np.array(y_test)

     def random_projection(
         self, n_components: int = None, sample: int = 1000
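A plausible motivation for the np.array cast above (inferred, not stated in the commit): when no test videos are requested, the test partition may be an empty Python list, and the cast keeps the return contract uniform, as this small sketch illustrates:

import numpy as np

# With an empty test partition, the cast still yields an ndarray,
# so downstream callers can always unpack four arrays
X_test, y_test = [], []
X_test, y_test = np.array(X_test), np.array(y_test)
print(X_test.shape, y_test.shape)  # (0,) (0,)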
@@ -521,7 +521,6 @@ class SEQ_2_SEQ_GMVAE:
     def build(self, input_shape: Tuple):
         """Builds the tf.keras model"""

-        print(input_shape)

         # Instanciate prior
         self.get_prior()
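For reference, a minimal usage sketch of the new deep_unsupervised_embedding static method introduced above. The array shapes, hyperparameter values, and checkpoint path are illustrative assumptions, not part of the commit; in the real pipeline the input tuple would come from a table_dict preprocessing step:

import numpy as np
import deepof.data

# Illustrative preprocessed split: (samples, window length, features)
X_train = np.random.normal(size=(1024, 24, 12)).astype(np.float32)
X_val = np.random.normal(size=(256, 24, 12)).astype(np.float32)
y_train, y_val = np.zeros((1024, 0)), np.zeros((256, 0))

# Without a pretrained checkpoint, the method returns the (untrained) GMVAE
gmvae = deepof.data.coordinates.deep_unsupervised_embedding(
    (X_train, y_train, X_val, y_val),
    encoding_size=6,
    n_components=25,
    loss="ELBO",
)

# With a pretrained checkpoint (hypothetical path), the weights are loaded
# and the individual model components are returned instead
encoder, generator, grouper, gmvae = deepof.data.coordinates.deep_unsupervised_embedding(
    (X_train, y_train, X_val, y_val),
    encoding_size=6,
    pretrained="weights/gmvae_checkpoint.h5",  # hypothetical file
)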