Commit e7fd0a08 authored by lucas_miranda's avatar lucas_miranda
Browse files

Removed outdated dead_neuron_control layer

parent da9d2393
......@@ -587,23 +587,3 @@ class Cluster_overlap(Layer):
self.add_loss(-intercomponent_mmd, inputs=[target])
return target
class Dead_neuron_control(Layer):
"""
Identity layer that adds latent space and clustering stats
to the metrics compiled by the model
"""
def __init__(self, *args, **kwargs):
super(Dead_neuron_control, self).__init__(*args, **kwargs)
# noinspection PyMethodOverriding
def call(self, target, **kwargs):
"""Updates Layer's call method"""
# Adds metric that monitors dead neurons in the latent space
self.add_metric(
tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
)
return target
......@@ -253,7 +253,6 @@ class SEQ_2_SEQ_GMVAE:
mmd_annealing_mode: str = "sigmoid",
mmd_warmup_epochs: int = 20,
montecarlo_kl: int = 1,
neuron_control: bool = False,
number_of_components: int = 1,
overlap_loss: float = 0.0,
next_sequence_prediction: float = 0.0,
......@@ -285,7 +284,6 @@ class SEQ_2_SEQ_GMVAE:
self.mc_kl = montecarlo_kl
self.mmd_annealing_mode = mmd_annealing_mode
self.mmd_warmup = mmd_warmup_epochs
self.neuron_control = neuron_control
self.number_of_components = number_of_components
self.optimizer = Nadam(lr=self.learn_rate, clipvalue=self.clipvalue)
self.overlap_loss = overlap_loss
......@@ -628,10 +626,6 @@ class SEQ_2_SEQ_GMVAE:
z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
# Identity layer controlling for dead neurons in the Gaussian Mixture posterior
if self.neuron_control:
z_gauss = deepof.model_utils.Dead_neuron_control()(z_gauss)
if self.overlap_loss:
z_gauss = deepof.model_utils.Cluster_overlap(
self.ENCODING,
......
......@@ -164,13 +164,6 @@ parser.add_argument(
default=10,
type=int,
)
parser.add_argument(
"--neuron-control",
"-nc",
help="If True, adds the proportion of dead neurons in the latent space as a metric",
type=str2bool,
default=False,
)
parser.add_argument(
"--output-path",
"-o",
......@@ -274,7 +267,6 @@ loss = args.loss
mmd_annealing_mode = args.mmd_annealing_mode
mmd_wu = args.mmd_warmup
mc_kl = args.montecarlo_kl
neuron_control = args.neuron_control
output_path = os.path.join(args.output_path)
overlap_loss = args.overlap_loss
next_sequence_prediction = float(args.next_sequence_prediction)
......
......@@ -411,7 +411,6 @@ def autoencoder_fitting(
mmd_annealing_mode=mmd_annealing_mode,
mmd_warmup_epochs=mmd_warmup,
montecarlo_kl=montecarlo_kl,
neuron_control=False,
number_of_components=n_components,
overlap_loss=False,
next_sequence_prediction=next_sequence_prediction,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment