Commit 358b099b authored by lucas_miranda

Added nose2body to rule_based_annotation()

parent fb591fe6
@@ -367,10 +367,10 @@ class MMDiscrepancyLayer(Layer):
         return z


-class Gaussian_mixture_overlap(Layer):
+class Cluster_overlap(Layer):
     """
     Identity layer that measures the overlap between the components of the latent Gaussian Mixture
-    using a specified metric (MMD, Wasserstein, Fischer-Rao)
+    using a specified metric (KNN-purity, MMD)
     """

     def __init__(self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs):
@@ -378,7 +378,7 @@ class Gaussian_mixture_overlap(Layer):
         self.n_components = n_components
         self.loss = loss
         self.samples = samples
-        super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)
+        super(Cluster_overlap, self).__init__(*args, **kwargs)

     def get_config(self):  # pragma: no cover
         """Updates Constraint metadata"""
...
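The rename tracks a change of metric: the layer now scores the overlap between mixture components with KNN purity (keeping MMD as an option) instead of the distributional distances listed before. The new forward pass is not shown in this hunk; the snippet below is only a minimal sketch of how a KNN-purity score over labelled latent samples can be computed (the knn_purity name and the sklearn-based implementation are illustrative, not deepof's code):

import numpy as np
from sklearn.neighbors import NearestNeighbors

def knn_purity(samples: np.ndarray, labels: np.ndarray, k: int = 10) -> float:
    # Fraction of each point's k nearest neighbours sharing its component
    # label, averaged over all points: ~1.0 for well-separated components,
    # ~1/n_components when the mixture components overlap heavily.
    nn = NearestNeighbors(n_neighbors=k + 1).fit(samples)
    _, idx = nn.kneighbors(samples)        # idx[:, 0] is each point itself
    neighbour_labels = labels[idx[:, 1:]]  # drop the self-match
    return float((neighbour_labels == labels[:, None]).mean())

# Toy check: two clearly separated components in an 8-d latent space
rng = np.random.default_rng(0)
z = np.vstack([rng.normal(0, 1, (100, 8)), rng.normal(6, 1, (100, 8))])
y = np.repeat([0, 1], 100)
print(knn_purity(z, y))  # close to 1.0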
@@ -334,15 +334,15 @@ class SEQ_2_SEQ_GMVAE:
         """Sets the default parameters for the model. Overwritable with a dictionary"""

         defaults = {
-            "bidirectional_merge": "concat",
+            "bidirectional_merge": "ave",
             "clipvalue": 1.0,
             "dense_activation": "relu",
-            "dense_layers_per_branch": 5,
-            "dropout_rate": 1e-3,
-            "learning_rate": 1e-3,
-            "units_conv": 32,
+            "dense_layers_per_branch": 1,
+            "dropout_rate": 0.15,
+            "learning_rate": 1e-4,
+            "units_conv": 64,
             "units_dense2": 32,
-            "units_lstm": 300,
+            "units_lstm": 128,
         }

         for k, v in params.items():
@@ -599,7 +599,7 @@ class SEQ_2_SEQ_GMVAE:
         z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss)

         if self.overlap_loss:
-            z_gauss = deepof.model_utils.Gaussian_mixture_overlap(
+            z_gauss = deepof.model_utils.Cluster_overlap(
                 self.ENCODING,
                 self.number_of_components,
                 loss=self.overlap_loss,
...
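The truncated loop after the defaults dictionary is the override mechanism the docstring mentions: user-supplied values shadow the defaults. A standalone sketch of that merge pattern (the get_hparams name here is an assumption; only the defaults and the loop header appear in the diff):

def get_hparams(params: dict) -> dict:
    # User-supplied hyperparameters take precedence over the defaults.
    defaults = {
        "bidirectional_merge": "ave",
        "clipvalue": 1.0,
        "dense_activation": "relu",
        "dense_layers_per_branch": 1,
        "dropout_rate": 0.15,
        "learning_rate": 1e-4,
        "units_conv": 64,
        "units_dense2": 32,
        "units_lstm": 128,
    }
    for k, v in params.items():  # same loop header as in the hunk above
        defaults[k] = v
    return defaults

print(get_hparams({"units_lstm": 256})["units_lstm"])  # 256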
@@ -222,7 +222,7 @@ hypertun_trials = args.hpt_trials
 encoding_size = args.encoding_size
 exclude_bodyparts = [i for i in args.exclude_bodyparts.split(",") if i]
 gaussian_filter = args.gaussian_filter
-hparams = args.hyperparameters
+hparams = args.hyperparameters if args.hyperparameters is not None else {}
 input_type = args.input_type
 k = args.components
 kl_wu = args.kl_warmup
@@ -261,7 +261,6 @@ assert input_type in [
 ], "Invalid input type. Type python model_training.py -h for help."

 # Loads model hyperparameters and treatment conditions, if available
-hparams = load_hparams(hparams)
 treatment_dict = load_treatments(train_path)

 # Logs hyperparameters if specified on the --logparam CLI argument
@@ -352,7 +351,7 @@ if not tune:
         (X_train, y_train, X_val, y_val),
         batch_size=batch_size,
         encoding_size=encoding_size,
-        hparams=hparams,
+        hparams={},
         kl_warmup=kl_wu,
         log_history=True,
         log_hparams=True,
...
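With load_hparams removed (next hunk), a missing --hyperparameters argument now resolves to an empty override dict, and the non-tuning branch passes hparams={} explicitly. Anyone still keeping tuned hyperparameters in a pickle, as hyperparameter_tuning.py used to produce, can load them by hand; a minimal sketch with an illustrative file path:

import pickle

hparams_path = "tuned_hparams.pkl"  # hypothetical pickle of hyperparameter overrides
try:
    with open(hparams_path, "rb") as handle:
        hparams = pickle.load(handle)  # dict of overrides, as load_hparams used to return
except FileNotFoundError:
    hparams = {}  # fall back to the model defaults, mirroring the new CLI behaviour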
@@ -46,28 +46,6 @@ class CustomStopper(tf.keras.callbacks.EarlyStopping):
         super().on_epoch_end(epoch, logs)


-def load_hparams(hparams):
-    """Loads hyperparameters from a custom dictionary pickled on disc.
-    Thought to be used with the output of hyperparameter_tuning.py"""
-
-    if hparams is not None:
-        with open(hparams, "rb") as handle:
-            hparams = pickle.load(handle)
-    else:
-        hparams = {
-            "bidirectional_merge": "ave",
-            "clipvalue": 1.0,
-            "dense_activation": "relu",
-            "dense_layers_per_branch": 1,
-            "dropout_rate": 1e-3,
-            "learning_rate": 1e-3,
-            "units_conv": 160,
-            "units_dense2": 120,
-            "units_lstm": 300,
-        }
-    return hparams
-
-
 def load_treatments(train_path):
     """Loads a dictionary containing the treatments per individual,
     to be loaded as metadata in the coordinates class"""
...
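load_treatments, which survives this cleanup, is only shown down to its docstring. As a rough illustration of a loader of this kind (the file format and scanning logic below are assumptions, not deepof's actual implementation):

import os
import pickle

def load_treatments_sketch(train_path: str):
    # Illustrative only: return the first pickled treatments dictionary
    # found under train_path, or None when no such file exists. The real
    # deepof.train_utils.load_treatments may expect a different format.
    for fname in sorted(os.listdir(train_path)):
        if fname.endswith(".pkl"):
            with open(os.path.join(train_path, fname), "rb") as handle:
                return pickle.load(handle)
    return None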
@@ -21,24 +21,6 @@ import os

 import tensorflow as tf


-def test_load_hparams():
-    assert type(deepof.train_utils.load_hparams(None)) == dict
-    assert (
-        type(
-            deepof.train_utils.load_hparams(
-                os.path.join(
-                    "tests",
-                    "test_examples",
-                    "test_single_topview",
-                    "Others",
-                    "test_hparams.pkl",
-                )
-            )
-        )
-        == dict
-    )
-
-
 def test_load_treatments():
     assert deepof.train_utils.load_treatments(".") is None
     assert (
...
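test_load_hparams goes away with the function it covered; test_load_treatments (partially shown) remains. A hypothetical regression test for the new CLI fallback could look like this (the test name and argparse setup are illustrative, not part of this commit):

import argparse

def test_hyperparameters_default_is_empty_dict():
    # When --hyperparameters is not given, hparams should resolve to {}.
    parser = argparse.ArgumentParser()
    parser.add_argument("--hyperparameters", default=None)
    args = parser.parse_args([])
    hparams = args.hyperparameters if args.hyperparameters is not None else {}
    assert hparams == {}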