diff --git a/source/model_utils.py b/source/model_utils.py index cead80764876f94a70c6fefb70c5571a71e42baf..7918608abd8814a3b069d99a725deed78321b9cc 100644 --- a/source/model_utils.py +++ b/source/model_utils.py @@ -168,7 +168,7 @@ class Gaussian_mixture_overlap(Layer): """ def __init__( - self, lat_dims, n_components, loss=False, samples=100, *args, **kwargs + self, lat_dims, n_components, loss=False, samples=10, *args, **kwargs ): self.lat_dims = lat_dims self.n_components = n_components @@ -201,7 +201,7 @@ class Gaussian_mixture_overlap(Layer): intercomponent_mmd = K.mean( tf.convert_to_tensor( [ - tf.map_fn(compute_mmd, [dists[c[0]], dists[c[1]]], dtype=tf.float32) + tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]]) for c in combinations(range(len(dists)), 2) ], dtype=tf.float32,