Commit 04541053 authored by lucas_miranda's avatar lucas_miranda
Browse files

Reformatted files

parent 1338ba03
......@@ -92,7 +92,7 @@ class SEQ_2_SEQ_AE(HyperModel):
ENCODING,
activation="relu",
kernel_constraint=UnitNorm(axis=1),
activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
activity_regularizer=uncorrelated_features_constraint(3, weightage=1.0),
kernel_initializer=Orthogonal(),
)
......@@ -340,7 +340,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
for k in range(self.number_of_components)
],
),
activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
activity_regularizer=uncorrelated_features_constraint(3, weightage=1.0),
)([z_cat, z_gauss])
if "ELBO" in self.loss:
......
......@@ -172,7 +172,14 @@ class one_cycle_scheduler(tf.keras.callbacks.Callback):
K.set_value(self.model.optimizer.lr, rate)
class UncorrelatedFeaturesConstraint(Constraint):
class uncorrelated_features_constraint(Constraint):
"""
Tensorflow Constraint subclass that forces a layer to have uncorrelated features.
Useful, among others, for auto encoder bottleneck layers
"""
def __init__(self, encoding_dim, weightage=1.0):
self.encoding_dim = encoding_dim
self.weightage = weightage
......
......@@ -92,7 +92,7 @@ class SEQ_2_SEQ_AE:
self.ENCODING,
activation="elu",
kernel_constraint=UnitNorm(axis=1),
activity_regularizer=UncorrelatedFeaturesConstraint(2, weightage=1.0),
activity_regularizer=uncorrelated_features_constraint(2, weightage=1.0),
kernel_initializer=Orthogonal(),
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment