Commit 5b8b49c2 authored by lucas_miranda's avatar lucas_miranda
Browse files

Started implementing annealing mode for KL divergence

parent 8d314040
......@@ -151,7 +151,9 @@ def test_dense_transpose():
# noinspection PyCallingNonCallable,PyUnresolvedReferences
def test_KLDivergenceLayer():
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(annealing_mode=st.one_of(st.just("linear"), st.just("sigmoid")))
def test_KLDivergenceLayer(annealing_mode):
X = tf.random.uniform([10, 2], 0, 10)
y = np.random.randint(0, 1, [10, 1])
......@@ -181,7 +183,7 @@ def test_KLDivergenceLayer():
weight=1.0,
)(x)
kl_deepof = deepof.model_utils.KLDivergenceLayer(
distribution_b=prior, iters=1, warm_up_iters=0
distribution_b=prior, iters=1, warm_up_iters=0, annealing_mode=annealing_mode,
)(x)
test_model = tf.keras.Model(i, [kl_canon, kl_deepof])
......@@ -196,7 +198,9 @@ def test_KLDivergenceLayer():
# noinspection PyUnresolvedReferences
def test_MMDiscrepancyLayer():
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(annealing_mode=st.one_of(st.just("linear"), st.just("sigmoid")))
def test_MMDiscrepancyLayer(annealing_mode):
X = tf.random.uniform([1500, 10], 0, 10)
y = np.random.randint(0, 2, [1500, 1])
......@@ -227,6 +231,7 @@ def test_MMDiscrepancyLayer():
prior=prior,
iters=1,
warm_up_iters=0,
annealing_mode=annealing_mode
)(x)
test_model = tf.keras.Model(i, x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment