Re-flowed from a whitespace-collapsed paste; line breaks were placed so every hunk matches its @@ counts.
Review amendment: Entropy_regulariser.get_config now ends with `return config` (it built the dict
but fell through and returned None), so that hunk's new-side count is 25 instead of 24 and the
model_utils.py post-image blob id on the index line no longer describes the amended result.
NOTE(review): indentation and the docstring wrap point in the first hunk are reconstructed - confirm against the tree.

diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index 87a23b05fbcbaacc53aedeb2b78c8e1b7ca1889a..5ba14df2b11bd52a2b8bb0a6cf7e00358014abd3 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -351,7 +351,7 @@ class MMDiscrepancyLayer(Layer):
         return z
 
 
-class Gaussian_mixture_overlap(Layer):
+class Gaussian_mixture_overlap(Layer):  # pragma: no cover
     """
     Identity layer that measures the overlap between the components of the latent Gaussian Mixture
     using a specified metric (MMD, Wasserstein, Fischer-Rao)
@@ -365,6 +365,8 @@ class Gaussian_mixture_overlap(Layer):
         super(Gaussian_mixture_overlap, self).__init__(*args, **kwargs)
 
     def get_config(self):
+        """Updates Layer metadata"""
+
         config = super().get_config().copy()
         config.update({"lat_dims": self.lat_dims})
         config.update({"n_components": self.n_components})
@@ -372,12 +374,14 @@ class Gaussian_mixture_overlap(Layer):
         config.update({"samples": self.samples})
         return config
 
-    def call(self, target, loss=False):
+    @tf.function
+    def call(self, target, **kwargs):
+        """Updates Layer's call method"""
 
         dists = []
         for k in range(self.n_components):
             locs = (target[..., : self.lat_dims, k],)
-            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
+            scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k])
 
             dists.append(
                 tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
@@ -385,7 +389,7 @@ class Gaussian_mixture_overlap(Layer):
 
         dists = [tf.transpose(gauss.sample(self.samples), [1, 0, 2]) for gauss in dists]
 
-        ### MMD-based overlap ###
+        # MMD-based overlap #
         intercomponent_mmd = K.mean(
             tf.convert_to_tensor(
                 [
@@ -415,13 +419,15 @@ class Dead_neuron_control(Layer):
     def __init__(self, *args, **kwargs):
         super(Dead_neuron_control, self).__init__(*args, **kwargs)
 
-    def call(self, z, z_gauss, z_cat, **kwargs):
+    # noinspection PyMethodOverriding
+    def call(self, target, **kwargs):
+        """Updates Layer's call method"""
         # Adds metric that monitors dead neurons in the latent space
         self.add_metric(
-            tf.math.zero_fraction(z_gauss), aggregation="mean", name="dead_neurons"
+            tf.math.zero_fraction(target), aggregation="mean", name="dead_neurons"
         )
 
-        return z
+        return target
 
 
 class Entropy_regulariser(Layer):
@@ -429,18 +435,25 @@ class Entropy_regulariser(Layer):
     Identity layer that adds cluster weight entropy to the loss function
     """
 
-    def __init__(self, weight=1.0, *args, **kwargs):
+    def __init__(self, weight=1.0, axis=1, *args, **kwargs):
         self.weight = weight
+        self.axis = axis
         super(Entropy_regulariser, self).__init__(*args, **kwargs)
 
     def get_config(self):
+        """Updates Layer metadata"""
+
         config = super().get_config().copy()
         config.update({"weight": self.weight})
+        config.update({"axis": self.axis})
+        return config
 
     def call(self, z, **kwargs):
+        """Updates Layer's call method"""
+
         # axis=1 increases the entropy of a cluster across instances
         # axis=0 increases the entropy of the assignment for a given instance
-        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=1)
+        entropy = K.sum(tf.multiply(z + 1e-5, tf.math.log(z) + 1e-5), axis=self.axis)
 
         # Adds metric that monitors dead neurons in the latent space
         self.add_metric(entropy, aggregation="mean", name="-weight_entropy")
diff --git a/deepof/models.py b/deepof/models.py
index 29e5b8e6335427d5505599ba606099ce0144c07f..1eb5d7968de801cbb7987b456c4b50bacd670def 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -335,6 +335,9 @@ class SEQ_2_SEQ_GMVAE:
 
         z_gauss = Reshape([2 * self.ENCODING, self.number_of_components])(z_gauss)
 
+        # Identity layer controlling for dead neurons in the Gaussian Mixture posterior
+        z_gauss = Dead_neuron_control()(z_gauss)
+
         if self.overlap_loss:
             z_gauss = Gaussian_mixture_overlap(
                 self.ENCODING, self.number_of_components, loss=self.overlap_loss,
@@ -387,9 +390,6 @@ class SEQ_2_SEQ_GMVAE:
                 batch_size=self.batch_size, prior=self.prior, beta=mmd_beta
             )(z)
 
-        # Identity layer controlling clustering and latent space statistics
-        z = Dead_neuron_control()(z, z_gauss, z_cat)
-
         # Define and instantiate generator
         generator = Model_D1(z)
         generator = Model_B1(generator)
diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py
index e70115b91a09f90370645f8cb8a8848014bdcad2..b350ab0b4025b4d790f9fa660eec939ce7cf3771 100644
--- a/tests/test_model_utils.py
+++ b/tests/test_model_utils.py
@@ -188,7 +188,7 @@ def test_MMDiscrepancyLayer():
     y = np.random.randint(0, 2, [1500, 1])
 
     prior = tfd.Independent(
-        tfd.Normal(loc=tf.zeros(10), scale=1, ), reinterpreted_batch_ndims=1,
+        tfd.Normal(loc=tf.zeros(10), scale=1,), reinterpreted_batch_ndims=1,
     )
 
     dense_1 = tf.keras.layers.Dense(10)
@@ -197,9 +197,10 @@ def test_MMDiscrepancyLayer():
     d = dense_1(i)
     x = tfpl.DistributionLambda(
         lambda dense: tfd.Independent(
-            tfd.Normal(loc=dense, scale=1, ), reinterpreted_batch_ndims=1,
+            tfd.Normal(loc=dense, scale=1,), reinterpreted_batch_ndims=1,
         )
     )(d)
+
     x = deepof.model_utils.MMDiscrepancyLayer(
         100, prior, beta=tf.keras.backend.variable(1.0, name="kl_beta")
     )(x)
@@ -213,21 +214,21 @@ def test_MMDiscrepancyLayer():
     assert type(fit) == tf.python.keras.callbacks.History
 
 
-#
-#
-# @settings(deadline=None)
-# @given()
-# def test_gaussian_mixture_overlap():
-#     pass
-#
-#
-# @settings(deadline=None)
-# @given()
-# def test_dead_neuron_control():
-#     pass
-#
-#
-# @settings(deadline=None)
-# @given()
+def test_dead_neuron_control():
+    X = np.random.uniform(0, 10, [1500, 5])
+    y = np.random.randint(0, 2, [1500, 1])
+
+    test_model = tf.keras.Sequential()
+    test_model.add(tf.keras.layers.Dense(1))
+    test_model.add(deepof.model_utils.Dead_neuron_control())
+
+    test_model.compile(
+        loss=tf.keras.losses.binary_crossentropy, optimizer=tf.keras.optimizers.SGD(),
+    )
+
+    fit = test_model.fit(X, y, epochs=10, batch_size=100)
+    assert type(fit) == tf.python.keras.callbacks.History
+
+
 # def test_entropy_regulariser():
 #     pass