Commit 7725eb28 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented mmd between Gaussian mixture components to track cluster...

Implemented mmd between Gaussian mixture components to track cluster overlapping. Adding a term to the loss function is an option
parent 2b01ccdd
......@@ -90,6 +90,20 @@ parser.add_argument(
default=16,
type=int,
)
parser.add_argument(
"--overlap-loss",
"-ol",
help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function",
default=False,
type=str2bool
)
parser.add_argument(
"--batch-size",
"-bs",
help="set training batch size. Defaults to 512",
type=int,
default=512
)
args = parser.parse_args()
train_path = os.path.abspath(args.train_path)
......@@ -103,6 +117,8 @@ kl_wu = args.kl_warmup
mmd_wu = args.mmd_warmup
hparams = args.hyperparameters
encoding = args.encoding_size
batch_size = args.batch_size
overlap_loss = args.overlap_loss
if not train_path:
raise ValueError("Set a valid data path for the training to run")
......@@ -372,7 +388,7 @@ if not variational:
x=input_dict_train[input_type],
y=input_dict_train[input_type],
epochs=250,
batch_size=512,
batch_size=batch_size,
verbose=1,
validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
callbacks=[
......@@ -399,6 +415,7 @@ else:
kl_warmup_epochs=kl_wu,
mmd_warmup_epochs=mmd_wu,
predictor=predictor,
overlap_loss=overlap_loss,
**hparams
).build()
gmvaep.build(input_dict_train[input_type].shape)
......@@ -423,7 +440,7 @@ else:
x=input_dict_train[input_type],
y=input_dict_train[input_type],
epochs=250,
batch_size=512,
batch_size=batch_size,
verbose=1,
validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
callbacks=callbacks_,
......@@ -433,7 +450,7 @@ else:
x=input_dict_train[input_type][:-1],
y=[input_dict_train[input_type][:-1], input_dict_train[input_type][1:]],
epochs=250,
batch_size=512,
batch_size=batch_size,
verbose=1,
validation_data=(
input_dict_val[input_type][:-1],
......
......@@ -2,6 +2,7 @@
from itertools import combinations
from keras import backend as K
from scipy.stats import wasserstein_distance
from sklearn.metrics import silhouette_score
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
......@@ -133,24 +134,26 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
class MMDiscrepancyLayer(Layer):
"""
Identity transform layer that adds MM discrepancy
Identity transform layer that adds MM Discrepancy
to the final model loss.
"""
def __init__(self, prior, beta=1.0, *args, **kwargs):
def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
self.is_placeholder = True
self.batch_size = batch_size
self.beta = beta
self.prior = prior
super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
def get_config(self):
config = super().get_config().copy()
config.update({"batch_size": self.batch_size})
config.update({"beta": self.beta})
config.update({"prior": self.prior})
return config
def call(self, z, **kwargs):
true_samples = self.prior.sample(1)
true_samples = self.prior.sample(self.batch_size)
mmd_batch = self.beta * compute_mmd([true_samples, z])
self.add_loss(K.mean(mmd_batch), inputs=z)
self.add_metric(mmd_batch, aggregation="mean", name="mmd")
......@@ -166,18 +169,10 @@ class Gaussian_mixture_overlap(Layer):
"""
def __init__(
self,
lat_dims,
n_components,
metric="mmd",
loss=False,
samples=100,
*args,
**kwargs
self, lat_dims, n_components, loss=False, samples=100, *args, **kwargs
):
self.lat_dims = lat_dims
self.n_components = n_components
self.metric = metric
self.loss = loss
self.samples = samples
super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)
......@@ -186,7 +181,6 @@ class Gaussian_mixture_overlap(Layer):
config = super().get_config().copy()
config.update({"lat_dims": self.lat_dims})
config.update({"n_components": self.n_components})
config.update({"metric": self.metric})
config.update({"loss": self.loss})
config.update({"samples": self.samples})
return config
......@@ -204,8 +198,7 @@ class Gaussian_mixture_overlap(Layer):
dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
if self.metric == "mmd":
### MMD-based overlap ###
intercomponent_mmd = K.mean(
tf.convert_to_tensor(
[
......@@ -223,9 +216,6 @@ class Gaussian_mixture_overlap(Layer):
if self.loss:
self.add_loss(-intercomponent_mmd, inputs=[target])
elif self.metric == "wasserstein":
pass
return target
......@@ -250,7 +240,7 @@ class Latent_space_control(Layer):
tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
)
# Adds Silhouette score controling overlap between clusters
# Adds Silhouette score controlling overlap between clusters
hard_labels = tf.math.argmax(z_cat, axis=1)
silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32)
self.add_metric(silhouette, aggregation="mean", name="silhouette")
......
......@@ -5,7 +5,7 @@ from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.activations import softplus
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform, Orthogonal, RandomNormal
from tensorflow.keras.initializers import he_uniform, Orthogonal
from tensorflow.keras.layers import BatchNormalization, Bidirectional
from tensorflow.keras.layers import Dense, Dropout, LSTM
from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
......@@ -155,6 +155,7 @@ class SEQ_2_SEQ_GMVAE:
def __init__(
self,
input_shape,
batch_size=512,
units_conv=256,
units_lstm=256,
units_dense2=64,
......@@ -167,10 +168,10 @@ class SEQ_2_SEQ_GMVAE:
prior="standard_normal",
number_of_components=1,
predictor=True,
overlap_metric="mmd",
overlap_loss=False,
):
self.input_shape = input_shape
self.batch_size = batch_size
self.CONV_filters = units_conv
self.LSTM_units_1 = units_lstm
self.LSTM_units_2 = int(units_lstm / 2)
......@@ -185,7 +186,6 @@ class SEQ_2_SEQ_GMVAE:
self.mmd_warmup = mmd_warmup_epochs
self.number_of_components = number_of_components
self.predictor = predictor
self.overlap_metric = overlap_metric
self.overlap_loss = overlap_loss
if self.prior == "standard_normal":
......@@ -303,10 +303,7 @@ class SEQ_2_SEQ_GMVAE:
z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
z_gauss = Gaussian_mixture_overlap(
self.ENCODING,
self.number_of_components,
metric=self.overlap_metric,
loss=self.overlap_loss,
self.ENCODING, self.number_of_components, loss=self.overlap_loss,
)(z_gauss)
z = tfpl.DistributionLambda(
......@@ -353,10 +350,12 @@ class SEQ_2_SEQ_GMVAE:
)
)
z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
z = MMDiscrepancyLayer(
batch_size=self.batch_size, prior=self.prior, beta=mmd_beta
)(z)
# Identity layer controlling clustering and latent space statistics
z = Latent_space_control()(z, z_gauss, z_cat)
z = Latent_space_control(loss=self.overlap_loss)(z, z_gauss, z_cat)
# Define and instantiate generator
generator = Model_D1(z)
......
Markdown is supported
Attach a file by drag & drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment