diff --git a/source/models.py b/source/models.py index 5c11780e9f8111dfff3991dc466a7638fb5ae448..95f5a4034a1d79dfacadef59e3f6b81cdc510c05 100644 --- a/source/models.py +++ b/source/models.py @@ -203,7 +203,7 @@ class SEQ_2_SEQ_GMVAE: loc=tf.random.uniform( shape=[self.ENCODING], minval=0, maxval=15 ), - scale=10, + scale=1, ), reinterpreted_batch_ndims=1, )