Commit e46d76a7 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented shuffle parameter in preprocessing; shuffled validation data in model_training.py

parent a3c034cc
......@@ -198,11 +198,11 @@ class Gaussian_mixture_overlap(Layer):
locs = (target[..., : self.lat_dims, k],)
scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
dists.append(tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1]))
dists.append(
tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
)
print(dists)
dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
print(dists)
if self.metric == "mmd":
......@@ -215,7 +215,7 @@ class Gaussian_mixture_overlap(Layer):
dtype=tf.float32,
)
)
print(intercomponent_mmd)
self.add_metric(
intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment