diff --git a/source/model_utils.py b/source/model_utils.py index ac45f8472f03b334526f9aa3e29d32c6936a9351..a2e824cee5260e10da3ec53359015b57efa1bbfb 100644 --- a/source/model_utils.py +++ b/source/model_utils.py @@ -198,11 +198,11 @@ class Gaussian_mixture_overlap(Layer): locs = (target[..., : self.lat_dims, k],) scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k]) - dists.append(tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])) + dists.append( + tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1]) + ) - print(dists) dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists] - print(dists) if self.metric == "mmd": @@ -215,7 +215,7 @@ class Gaussian_mixture_overlap(Layer): dtype=tf.float32, ) ) - print(intercomponent_mmd) + self.add_metric( intercomponent_mmd, aggregation="mean", name="intercomponent_mmd" )