Skip to content
Snippets Groups Projects
Commit e46d76a7 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Implemented shuffle parameter in preprocessing; shuffled validation data in model_training.py

parent a3c034cc
No related branches found
No related tags found
No related merge requests found
...@@ -198,11 +198,11 @@ class Gaussian_mixture_overlap(Layer): ...@@ -198,11 +198,11 @@ class Gaussian_mixture_overlap(Layer):
locs = (target[..., : self.lat_dims, k],) locs = (target[..., : self.lat_dims, k],)
scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k]) scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
dists.append(tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])) dists.append(
tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
)
print(dists)
dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists] dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
print(dists)
if self.metric == "mmd": if self.metric == "mmd":
...@@ -215,7 +215,7 @@ class Gaussian_mixture_overlap(Layer): ...@@ -215,7 +215,7 @@ class Gaussian_mixture_overlap(Layer):
dtype=tf.float32, dtype=tf.float32,
) )
) )
print(intercomponent_mmd)
self.add_metric( self.add_metric(
intercomponent_mmd, aggregation="mean", name="intercomponent_mmd" intercomponent_mmd, aggregation="mean", name="intercomponent_mmd"
) )
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment