Commit 9e636771 authored by lucas_miranda's avatar lucas_miranda
Browse files

Implemented mmd between Gaussian mixture components to track cluster...

Implemented mmd between Gaussian mixture components to track cluster overlapping. Adding a term to the loss function is an option
parent aac5e924
......@@ -102,7 +102,7 @@ parser.add_argument(
"-bs",
help="set training batch size. Defaults to 512",
type=int,
default=32
default=512
)
args = parser.parse_args()
......
......@@ -201,7 +201,7 @@ class Gaussian_mixture_overlap(Layer):
intercomponent_mmd = K.mean(
tf.convert_to_tensor(
[
tf.vectorized_map(compute_mmd, [dists[c[0]], dists[c[1]]])
tf.map_fn(compute_mmd, [dists[c[0]], dists[c[1]]], dtype=tf.float32)
for c in combinations(range(len(dists)), 2)
],
dtype=tf.float32,
......
......@@ -448,6 +448,7 @@ class SEQ_2_SEQ_GMVAE:
# TODO:
# - Add a parameter to enable/disable the mmd control (too slow if not needed)
# - design clustering-conscious hyperparameter tuning pipeline
# - execute the pipeline ;)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment