Commit f29e99ed authored by lucas_miranda's avatar lucas_miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent cc2733c7
......@@ -13,6 +13,8 @@ from hypothesis import HealthCheck
from hypothesis import settings
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays
from scipy.stats import entropy
from sklearn.neighbors import NearestNeighbors
import deepof.models
import deepof.model_utils
import numpy as np
......@@ -28,6 +30,37 @@ tfpl = tfp.layers
tfd = tfp.distributions
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
tensor=arrays(
shape=[10],
dtype=int,
unique=False,
elements=st.integers(min_value=0, max_value=10),
)
)
def test_compute_shannon_entropy(tensor):
deepof_tensor_entropy = deepof.model_utils.compute_shannon_entropy(tensor).numpy()
assert np.allclose(
np.round(deepof_tensor_entropy, 4), entropy(np.bincount(tensor)), rtol=1e-3
)
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
tensor=arrays(
shape=[100, 10],
dtype=float,
unique=False,
elements=st.floats(min_value=0.0, max_value=10.0),
)
)
def test_k_nearest_neighbors(tensor, k):
deepof_knn = deepof.model_utils.get_k_nearest_neighbors(tensor, k, 0)
sklearn_knn = NearestNeighbors().fit(tensor)
assert np.allclose(deepof_knn, sklearn_knn.kneighbors())
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
tensor=arrays(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment