Commit a3c034cc authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented shuffle parameter in preprocessing; shuffled validation data in

parent 5411ea58
...@@ -151,7 +151,7 @@ class MMDiscrepancyLayer(Layer): ...@@ -151,7 +151,7 @@ class MMDiscrepancyLayer(Layer):
def call(self, z, **kwargs): def call(self, z, **kwargs):
true_samples = self.prior.sample(1) true_samples = self.prior.sample(1)
mmd_batch = self.beta * compute_mmd(true_samples, z) mmd_batch = self.beta * compute_mmd([true_samples, z])
self.add_loss(K.mean(mmd_batch), inputs=z) self.add_loss(K.mean(mmd_batch), inputs=z)
self.add_metric(mmd_batch, aggregation="mean", name="mmd") self.add_metric(mmd_batch, aggregation="mean", name="mmd")
self.add_metric(self.beta, aggregation="mean", name="mmd_rate") self.add_metric(self.beta, aggregation="mean", name="mmd_rate")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment