Commit 7725eb28 authored by lucas_miranda
Browse files

Implemented mmd between Gaussian mixture components to track cluster...

Implemented MMD between Gaussian mixture components to track cluster overlap. Adding the negative MMD as a term of the loss function is optional.
parent 2b01ccdd
...@@ -90,6 +90,20 @@ parser.add_argument( ...@@ -90,6 +90,20 @@ parser.add_argument(
default=16, default=16,
type=int, type=int,
) )
parser.add_argument(
"--overlap-loss",
"-ol",
help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function",
default=False,
type=str2bool
)
parser.add_argument(
"--batch-size",
"-bs",
help="set training batch size. Defaults to 512",
type=int,
default=512
)
args = parser.parse_args() args = parser.parse_args()
train_path = os.path.abspath(args.train_path) train_path = os.path.abspath(args.train_path)
...@@ -103,6 +117,8 @@ kl_wu = args.kl_warmup ...@@ -103,6 +117,8 @@ kl_wu = args.kl_warmup
mmd_wu = args.mmd_warmup mmd_wu = args.mmd_warmup
hparams = args.hyperparameters hparams = args.hyperparameters
encoding = args.encoding_size encoding = args.encoding_size
batch_size = args.batch_size
overlap_loss = args.overlap_loss
if not train_path: if not train_path:
raise ValueError("Set a valid data path for the training to run") raise ValueError("Set a valid data path for the training to run")
...@@ -372,7 +388,7 @@ if not variational: ...@@ -372,7 +388,7 @@ if not variational:
x=input_dict_train[input_type], x=input_dict_train[input_type],
y=input_dict_train[input_type], y=input_dict_train[input_type],
epochs=250, epochs=250,
batch_size=512, batch_size=batch_size,
verbose=1, verbose=1,
validation_data=(input_dict_val[input_type], input_dict_val[input_type]), validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
callbacks=[ callbacks=[
...@@ -399,6 +415,7 @@ else: ...@@ -399,6 +415,7 @@ else:
kl_warmup_epochs=kl_wu, kl_warmup_epochs=kl_wu,
mmd_warmup_epochs=mmd_wu, mmd_warmup_epochs=mmd_wu,
predictor=predictor, predictor=predictor,
overlap_loss=overlap_loss,
**hparams **hparams
).build() ).build()
gmvaep.build(input_dict_train[input_type].shape) gmvaep.build(input_dict_train[input_type].shape)
...@@ -423,7 +440,7 @@ else: ...@@ -423,7 +440,7 @@ else:
x=input_dict_train[input_type], x=input_dict_train[input_type],
y=input_dict_train[input_type], y=input_dict_train[input_type],
epochs=250, epochs=250,
batch_size=512, batch_size=batch_size,
verbose=1, verbose=1,
validation_data=(input_dict_val[input_type], input_dict_val[input_type]), validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
callbacks=callbacks_, callbacks=callbacks_,
...@@ -433,7 +450,7 @@ else: ...@@ -433,7 +450,7 @@ else:
x=input_dict_train[input_type][:-1], x=input_dict_train[input_type][:-1],
y=[input_dict_train[input_type][:-1], input_dict_train[input_type][1:]], y=[input_dict_train[input_type][:-1], input_dict_train[input_type][1:]],
epochs=250, epochs=250,
batch_size=512, batch_size=batch_size,
verbose=1, verbose=1,
validation_data=( validation_data=(
input_dict_val[input_type][:-1], input_dict_val[input_type][:-1],
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
from itertools import combinations from itertools import combinations
from keras import backend as K from keras import backend as K
from scipy.stats import wasserstein_distance
from sklearn.metrics import silhouette_score from sklearn.metrics import silhouette_score
from tensorflow.keras.constraints import Constraint from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer from tensorflow.keras.layers import Layer
...@@ -133,24 +134,26 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss): ...@@ -133,24 +134,26 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
class MMDiscrepancyLayer(Layer): class MMDiscrepancyLayer(Layer):
""" """
Identity transform layer that adds MM discrepancy Identity transform layer that adds MM Discrepancy
to the final model loss. to the final model loss.
""" """
def __init__(self, prior, beta=1.0, *args, **kwargs): def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
self.is_placeholder = True self.is_placeholder = True
self.batch_size = batch_size
self.beta = beta self.beta = beta
self.prior = prior self.prior = prior
super(MMDiscrepancyLayer, self).__init__(*args, **kwargs) super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
def get_config(self): def get_config(self):
config = super().get_config().copy() config = super().get_config().copy()
config.update({"batch_size": self.batch_size})
config.update({"beta": self.beta}) config.update({"beta": self.beta})
config.update({"prior": self.prior}) config.update({"prior": self.prior})
return config return config
def call(self, z, **kwargs): def call(self, z, **kwargs):
true_samples = self.prior.sample(1) true_samples = self.prior.sample(self.batch_size)
mmd_batch = self.beta * compute_mmd([true_samples, z]) mmd_batch = self.beta * compute_mmd([true_samples, z])
self.add_loss(K.mean(mmd_batch), inputs=z) self.add_loss(K.mean(mmd_batch), inputs=z)
self.add_metric(mmd_batch, aggregation="mean", name="mmd") self.add_metric(mmd_batch, aggregation="mean", name="mmd")
...@@ -166,18 +169,10 @@ class Gaussian_mixture_overlap(Layer): ...@@ -166,18 +169,10 @@ class Gaussian_mixture_overlap(Layer):
""" """
def __init__( def __init__(
self, self, lat_dims, n_components, loss=False, samples=100, *args, **kwargs
lat_dims,
n_components,
metric="mmd",
loss=False,
samples=100,
*args,
**kwargs
): ):
self.lat_dims = lat_dims self.lat_dims = lat_dims
self.n_components = n_components self.n_components = n_components
self.metric = metric
self.loss = loss self.loss = loss
self.samples = samples self.samples = samples
super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs) super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)
...@@ -186,7 +181,6 @@ class Gaussian_mixture_overlap(Layer): ...@@ -186,7 +181,6 @@ class Gaussian_mixture_overlap(Layer):
config = super().get_config().copy() config = super().get_config().copy()
config.update({"lat_dims": self.lat_dims}) config.update({"lat_dims": self.lat_dims})
config.update({"n_components": self.n_components}) config.update({"n_components": self.n_components})
config.update({"metric": self.metric})
config.update({"loss": self.loss}) config.update({"loss": self.loss})
config.update({"samples": self.samples}) config.update({"samples": self.samples})
return config return config
...@@ -204,27 +198,23 @@ class Gaussian_mixture_overlap(Layer): ...@@ -204,27 +198,23 @@ class Gaussian_mixture_overlap(Layer):
dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists] dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
if self.metric == "mmd": ### MMD-based overlap ###
intercomponent_mmd = K.mean(
intercomponent_mmd = K.mean( tf.convert_to_tensor(
tf.convert_to_tensor( [
[ tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]]) for c in combinations(range(len(dists)), 2)
for c in combinations(range(len(dists)), 2) ],
], dtype=tf.float32,
dtype=tf.float32,
)
)
self.add_metric(
intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
) )
)
if self.loss: self.add_metric(
self.add_loss(-intercomponent_mmd, inputs=[target]) intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
)
elif self.metric == "wasserstein": if self.loss:
pass self.add_loss(-intercomponent_mmd, inputs=[target])
return target return target
...@@ -250,7 +240,7 @@ class Latent_space_control(Layer): ...@@ -250,7 +240,7 @@ class Latent_space_control(Layer):
tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons" tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
) )
# Adds Silhouette score controling overlap between clusters # Adds Silhouette score controlling overlap between clusters
hard_labels = tf.math.argmax(z_cat, axis=1) hard_labels = tf.math.argmax(z_cat, axis=1)
silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32) silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32)
self.add_metric(silhouette, aggregation="mean", name="silhouette") self.add_metric(silhouette, aggregation="mean", name="silhouette")
......
...@@ -5,7 +5,7 @@ from tensorflow.keras import Input, Model, Sequential ...@@ -5,7 +5,7 @@ from tensorflow.keras import Input, Model, Sequential
from tensorflow.keras.activations import softplus from tensorflow.keras.activations import softplus
from tensorflow.keras.callbacks import LambdaCallback from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.constraints import UnitNorm from tensorflow.keras.constraints import UnitNorm
from tensorflow.keras.initializers import he_uniform, Orthogonal, RandomNormal from tensorflow.keras.initializers import he_uniform, Orthogonal
from tensorflow.keras.layers import BatchNormalization, Bidirectional from tensorflow.keras.layers import BatchNormalization, Bidirectional
from tensorflow.keras.layers import Dense, Dropout, LSTM from tensorflow.keras.layers import Dense, Dropout, LSTM
from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
...@@ -155,6 +155,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -155,6 +155,7 @@ class SEQ_2_SEQ_GMVAE:
def __init__( def __init__(
self, self,
input_shape, input_shape,
batch_size=512,
units_conv=256, units_conv=256,
units_lstm=256, units_lstm=256,
units_dense2=64, units_dense2=64,
...@@ -167,10 +168,10 @@ class SEQ_2_SEQ_GMVAE: ...@@ -167,10 +168,10 @@ class SEQ_2_SEQ_GMVAE:
prior="standard_normal", prior="standard_normal",
number_of_components=1, number_of_components=1,
predictor=True, predictor=True,
overlap_metric="mmd",
overlap_loss=False, overlap_loss=False,
): ):
self.input_shape = input_shape self.input_shape = input_shape
self.batch_size = batch_size
self.CONV_filters = units_conv self.CONV_filters = units_conv
self.LSTM_units_1 = units_lstm self.LSTM_units_1 = units_lstm
self.LSTM_units_2 = int(units_lstm / 2) self.LSTM_units_2 = int(units_lstm / 2)
...@@ -185,7 +186,6 @@ class SEQ_2_SEQ_GMVAE: ...@@ -185,7 +186,6 @@ class SEQ_2_SEQ_GMVAE:
self.mmd_warmup = mmd_warmup_epochs self.mmd_warmup = mmd_warmup_epochs
self.number_of_components = number_of_components self.number_of_components = number_of_components
self.predictor = predictor self.predictor = predictor
self.overlap_metric = overlap_metric
self.overlap_loss = overlap_loss self.overlap_loss = overlap_loss
if self.prior == "standard_normal": if self.prior == "standard_normal":
...@@ -303,10 +303,7 @@ class SEQ_2_SEQ_GMVAE: ...@@ -303,10 +303,7 @@ class SEQ_2_SEQ_GMVAE:
z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss) z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
z_gauss = Gaussian_mixture_overlap( z_gauss = Gaussian_mixture_overlap(
self.ENCODING, self.ENCODING, self.number_of_components, loss=self.overlap_loss,
self.number_of_components,
metric=self.overlap_metric,
loss=self.overlap_loss,
)(z_gauss) )(z_gauss)
z = tfpl.DistributionLambda( z = tfpl.DistributionLambda(
...@@ -353,10 +350,12 @@ class SEQ_2_SEQ_GMVAE: ...@@ -353,10 +350,12 @@ class SEQ_2_SEQ_GMVAE:
) )
) )
z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z) z = MMDiscrepancyLayer(
batch_size=self.batch_size, prior=self.prior, beta=mmd_beta
)(z)
# Identity layer controlling clustering and latent space statistics # Identity layer controlling clustering and latent space statistics
z = Latent_space_control()(z, z_gauss, z_cat) z = Latent_space_control(loss=self.overlap_loss)(z, z_gauss, z_cat)
# Define and instantiate generator # Define and instantiate generator
generator = Model_D1(z) generator = Model_D1(z)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment