Skip to content
Snippets Groups Projects
Commit e068eb57 authored by Lucas Miranda's avatar Lucas Miranda
Browse files

Replaced for loop with vectorised mapping on ClusterOverlap regularization layer

parent f7772dc7
Branches
Tags
No related merge requests found
Pipeline #101805 passed
...@@ -51,14 +51,41 @@ def test_compute_shannon_entropy(tensor): ...@@ -51,14 +51,41 @@ def test_compute_shannon_entropy(tensor):
tensor=arrays( tensor=arrays(
shape=[100, 10], shape=[100, 10],
dtype=float, dtype=float,
unique=False, unique=True,
elements=st.floats(min_value=0.0, max_value=10.0), elements=st.floats(min_value=0.0, max_value=10.0),
) ),
k=st.integers(min_value=5, max_value=20),
) )
def test_k_nearest_neighbors(tensor, k): def test_k_nearest_neighbors(tensor, k):
deepof_knn = deepof.model_utils.get_k_nearest_neighbors(tensor, k, 0) deepof_knn = deepof.model_utils.get_k_nearest_neighbors(tensor, k, 0)
sklearn_knn = NearestNeighbors().fit(tensor) sklearn_knn = NearestNeighbors(k).fit(tensor)
assert np.allclose(deepof_knn, sklearn_knn.kneighbors()) sklearn_knn = sklearn_knn.kneighbors(tensor[0].reshape(1, -1))[1].flatten()
assert np.allclose(deepof_knn.numpy(), sorted(sklearn_knn))
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(
tensor=arrays(
shape=[100, 10],
dtype=float,
unique=True,
elements=st.floats(min_value=0.0, max_value=10.0),
),
clusters=arrays(
shape=[100],
dtype=int,
unique=False,
elements=st.integers(min_value=0, max_value=10),
),
k=st.integers(min_value=5, max_value=20),
)
def test_get_neighbourhood_entropy(tensor, clusters, k):
neighborhood_entropy = deepof.model_utils.get_neighbourhood_entropy(
0, tensor, clusters, k
).numpy()
assert isinstance(neighborhood_entropy, np.float32)
@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow]) @settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
... ...
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment