diff --git a/model_training.py b/model_training.py
index c222b819e9c51be85795586f5dfed1a2149e6da0..4d3c0078b11d42beebd496e5bc3a3f900fa97d66 100644
--- a/model_training.py
+++ b/model_training.py
@@ -90,6 +90,20 @@ parser.add_argument(
     default=16,
     type=int,
 )
+parser.add_argument(
+    "--overlap-loss",
+    "-ol",
+    help="If True, adds the negative MMD between all components of the latent Gaussian mixture to the loss function",
+    default=False,
+    type=str2bool
+)
+parser.add_argument(
+    "--batch-size",
+    "-bs",
+    help="set training batch size. Defaults to 512",
+    type=int,
+    default=512
+)
 args = parser.parse_args()
 
 train_path = os.path.abspath(args.train_path)
@@ -103,6 +117,8 @@ kl_wu = args.kl_warmup
 mmd_wu = args.mmd_warmup
 hparams = args.hyperparameters
 encoding = args.encoding_size
+batch_size = args.batch_size
+overlap_loss = args.overlap_loss
 
 if not train_path:
     raise ValueError("Set a valid data path for the training to run")
@@ -372,7 +388,7 @@ if not variational:
         x=input_dict_train[input_type],
         y=input_dict_train[input_type],
         epochs=250,
-        batch_size=512,
+        batch_size=batch_size,
         verbose=1,
         validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
         callbacks=[
@@ -399,6 +415,8 @@ else:
         kl_warmup_epochs=kl_wu,
         mmd_warmup_epochs=mmd_wu,
         predictor=predictor,
+        batch_size=batch_size,
+        overlap_loss=overlap_loss,
         **hparams
     ).build()
     gmvaep.build(input_dict_train[input_type].shape)
@@ -423,7 +441,7 @@ else:
         x=input_dict_train[input_type],
         y=input_dict_train[input_type],
         epochs=250,
-        batch_size=512,
+        batch_size=batch_size,
         verbose=1,
         validation_data=(input_dict_val[input_type], input_dict_val[input_type]),
         callbacks=callbacks_,
@@ -433,7 +451,7 @@ else:
         x=input_dict_train[input_type][:-1],
         y=[input_dict_train[input_type][:-1], input_dict_train[input_type][1:]],
         epochs=250,
-        batch_size=512,
+        batch_size=batch_size,
         verbose=1,
         validation_data=(
             input_dict_val[input_type][:-1],
diff --git a/source/model_utils.py b/source/model_utils.py
index b27cfd3fc4882ece1218630f5edc1dc010aef038..e705390be1040a97726547616885c3ec69bda4bc 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -133,24 +133,26 @@ class KLDivergenceLayer(tfpl.KLDivergenceAddLoss):
 
 class MMDiscrepancyLayer(Layer):
     """
-    Identity transform layer that adds MM discrepancy
+    Identity transform layer that adds MM Discrepancy
    to the final model loss.
     """
 
-    def __init__(self, prior, beta=1.0, *args, **kwargs):
+    def __init__(self, batch_size, prior, beta=1.0, *args, **kwargs):
         self.is_placeholder = True
+        self.batch_size = batch_size
         self.beta = beta
         self.prior = prior
         super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
 
     def get_config(self):
         config = super().get_config().copy()
+        config.update({"batch_size": self.batch_size})
         config.update({"beta": self.beta})
         config.update({"prior": self.prior})
         return config
 
     def call(self, z, **kwargs):
-        true_samples = self.prior.sample(1)
+        true_samples = self.prior.sample(self.batch_size)
         mmd_batch = self.beta * compute_mmd([true_samples, z])
         self.add_loss(K.mean(mmd_batch), inputs=z)
         self.add_metric(mmd_batch, aggregation="mean", name="mmd")
@@ -166,18 +168,10 @@ class Gaussian_mixture_overlap(Layer):
     """
 
     def __init__(
-        self,
-        lat_dims,
-        n_components,
-        metric="mmd",
-        loss=False,
-        samples=100,
-        *args,
-        **kwargs
+        self, lat_dims, n_components, loss=False, samples=100, *args, **kwargs
     ):
         self.lat_dims = lat_dims
         self.n_components = n_components
-        self.metric = metric
         self.loss = loss
         self.samples = samples
         super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)
@@ -186,7 +180,6 @@ def get_config(self):
         config = super().get_config().copy()
         config.update({"lat_dims": self.lat_dims})
         config.update({"n_components": self.n_components})
-        config.update({"metric": self.metric})
         config.update({"loss": self.loss})
         config.update({"samples": self.samples})
         return config
@@ -204,27 +197,23 @@
 
         dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
 
-        if self.metric == "mmd":
-
-            intercomponent_mmd = K.mean(
-                tf.convert_to_tensor(
-                    [
-                        tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
-                        for c in combinations(range(len(dists)), 2)
-                    ],
-                    dtype=tf.float32,
-                )
-            )
-
-            self.add_metric(
-                intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
+        ### MMD-based overlap ###
+        intercomponent_mmd = K.mean(
+            tf.convert_to_tensor(
+                [
+                    tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
+                    for c in combinations(range(len(dists)), 2)
+                ],
+                dtype=tf.float32,
             )
+        )
 
-            if self.loss:
-                self.add_loss(-intercomponent_mmd, inputs=[target])
+        self.add_metric(
+            intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
+        )
 
-        elif self.metric == "wasserstein":
-            pass
+        if self.loss:
+            self.add_loss(-intercomponent_mmd, inputs=[target])
 
         return target
 
@@ -250,7 +239,7 @@
             tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
         )
 
-        # Adds Silhouette score controling overlap between clusters
+        # Adds Silhouette score controlling overlap between clusters
         hard_labels = tf.math.argmax(z_cat, axis=1)
         silhouette = tf.numpy_function(silhouette_score, [z, hard_labels], tf.float32)
         self.add_metric(silhouette, aggregation="mean", name="silhouette")
diff --git a/source/models.py b/source/models.py
index 90b69b349597b738bf67746bc50a4c276b8e00ac..afdf3843cb78f5b6a713194658de67dee6b4dcab 100644
--- a/source/models.py
+++ b/source/models.py
@@ -5,7 +5,7 @@ from tensorflow.keras import Input, Model, Sequential
 from tensorflow.keras.activations import softplus
 from tensorflow.keras.callbacks import LambdaCallback
 from tensorflow.keras.constraints import UnitNorm
-from tensorflow.keras.initializers import he_uniform, Orthogonal, RandomNormal
+from tensorflow.keras.initializers import he_uniform, Orthogonal
 from tensorflow.keras.layers import BatchNormalization, Bidirectional
 from tensorflow.keras.layers import Dense, Dropout, LSTM
 from tensorflow.keras.layers import RepeatVector, Reshape, TimeDistributed
@@ -155,6 +155,7 @@ class SEQ_2_SEQ_GMVAE:
     def __init__(
         self,
         input_shape,
+        batch_size=512,
         units_conv=256,
         units_lstm=256,
         units_dense2=64,
@@ -167,10 +168,10 @@ class SEQ_2_SEQ_GMVAE:
         prior="standard_normal",
         number_of_components=1,
         predictor=True,
-        overlap_metric="mmd",
         overlap_loss=False,
     ):
         self.input_shape = input_shape
+        self.batch_size = batch_size
         self.CONV_filters = units_conv
         self.LSTM_units_1 = units_lstm
         self.LSTM_units_2 = int(units_lstm / 2)
@@ -185,7 +186,6 @@ class SEQ_2_SEQ_GMVAE:
         self.mmd_warmup = mmd_warmup_epochs
         self.number_of_components = number_of_components
         self.predictor = predictor
-        self.overlap_metric = overlap_metric
         self.overlap_loss = overlap_loss
 
         if self.prior == "standard_normal":
@@ -303,10 +303,7 @@
         z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
 
         z_gauss = Gaussian_mixture_overlap(
-            self.ENCODING,
-            self.number_of_components,
-            metric=self.overlap_metric,
-            loss=self.overlap_loss,
+            self.ENCODING, self.number_of_components, loss=self.overlap_loss,
         )(z_gauss)
 
         z = tfpl.DistributionLambda(
@@ -353,10 +350,12 @@
             )
         )
 
-        z = MMDiscrepancyLayer(prior=self.prior, beta=mmd_beta)(z)
+        z = MMDiscrepancyLayer(
+            batch_size=self.batch_size, prior=self.prior, beta=mmd_beta
+        )(z)
 
         # Identity layer controlling clustering and latent space statistics
-        z = Latent_space_control()(z, z_gauss, z_cat)
+        z = Latent_space_control(loss=self.overlap_loss)(z, z_gauss, z_cat)
 
         # Define and instantiate generator
         generator = Model_D1(z)