Commit ff091f9a authored by lucas_miranda's avatar lucas_miranda
Browse files

Added latent regularization control to deepof.data.coordinates.deep_unsupervised_embedding()

parent 33a4fa0a
Pipeline #93315 canceled with stage
in 9 minutes and 30 seconds
......@@ -183,14 +183,6 @@ parser.add_argument(
type=float,
default=0.99,
)
parser.add_argument(
"--stability-check",
"-s",
help="Sets the number of times that the model is trained and initialised. "
"If greater than 1 (the default), saves the cluster assignments to a dataframe on disk",
type=int,
default=1,
)
parser.add_argument("--train-path", "-tp", help="set training set path", type=str)
parser.add_argument(
"--val-num",
......@@ -243,7 +235,6 @@ output_path = os.path.join(args.output_path)
overlap_loss = args.overlap_loss
pheno_class = float(args.phenotype_classifier)
predictor = float(args.predictor)
runs = args.stability_check
smooth_alpha = args.smooth_alpha
train_path = os.path.abspath(args.train_path)
tune = args.hyperparameter_tuning
......@@ -287,7 +278,7 @@ project_coords = project(
animal_ids=tuple([animal_id]),
arena="circular",
arena_dims=tuple([arena_dims]),
enable_iterative_imputation=True,
enable_iterative_imputation=False,
exclude_bodyparts=exclude_bodyparts,
exp_conditions=treatment_dict,
path=train_path,
......@@ -357,6 +348,10 @@ print("Done!")
# as many times as specified by runs
if not tune:
print(latent_reg)
print(("categorical" in latent_reg))
print(("variance" in latent_reg))
trained_models = project_coords.deep_unsupervised_embedding(
(X_train, y_train, X_val, y_val),
batch_size=batch_size,
......
......@@ -430,7 +430,7 @@ def autoencoder_fitting(
)
if save_weights:
ae.save_weights("{}_final_weights.h5".format(run_ID))
ae.save_weights("{}{}_final_weights.h5".format(output_path, run_ID))
if log_hparams:
# Logparams to tensorboard
......
......@@ -99,9 +99,8 @@ rule latent_regularization_experiments:
"--encoding-size {wildcards.encs} "
"--batch-size 256 "
"--window-size 24 "
"--window-step 6 "
"--window-step 12 "
"--exclude-bodyparts Tail_base,Tail_1,Tail_2,Tail_tip "
"--stability-check 3 "
"--output-path {outpath}latent_regularization_experiments"
#
#
......
......@@ -35,7 +35,7 @@ def test_SEQ_2_SEQ_AE_hypermodel_build(input_shape):
@given(
encoding_size=st.integers(min_value=2, max_value=16),
loss=st.one_of(st.just("ELBO"), st.just("MMD"), st.just("ELBO+MMD")),
number_of_components=st.integers(min_value=1, max_value=5),
noumber_of_components=st.integers(min_value=1, max_value=5),
)
def test_SEQ_2_SEQ_GMVAE_hypermodel_build(
encoding_size,
......
......@@ -75,7 +75,16 @@ def test_get_callbacks(
loss,
):
runID, tbc, cycle1c, cpc = deepof.train_utils.get_callbacks(
X_train, batch_size, True, variational, pheno_class, predictor, loss, None
X_train,
batch_size,
True,
variational,
pheno_class,
predictor,
loss,
True,
True,
None,
)
assert type(runID) == str
assert type(tbc) == tf.keras.callbacks.TensorBoard
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment