diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index ed26b94eb7f84cad870b7066f62a64ff9e38bdd2..391341995221234af8369af32defb9ee87631e1a 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -9,6 +9,8 @@ Functions and general utilities for the deepof tensorflow models. See documentat
 """
 
 from itertools import combinations
+from typing import Any, Tuple
+
 from tensorflow.keras import backend as K
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
@@ -136,7 +138,7 @@ def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
 
 
 @tf.function
-def compute_mmd(tensors: tuple) -> tf.Tensor:
+def compute_mmd(tensors: Tuple[Any, Any]) -> tf.Tensor:
     """
 
     Computes the MMD between the two specified vectors using a gaussian kernel.
@@ -317,7 +319,9 @@ class DenseTranspose(Layer):
         """Updates Layer's build method"""
 
         self.biases = self.add_weight(
-            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
+            name="bias",
+            shape=self.dense.get_input_at(-1).get_shape().as_list(),
+            initializer="zeros",
         )
         super().build(batch_input_shape)
 
@@ -390,6 +394,7 @@ class MMDiscrepancyLayer(Layer):
         """Updates Layer's call method"""
 
         true_samples = self.prior.sample(self.batch_size)
+        # noinspection PyTypeChecker
         mmd_batch = self.beta * compute_mmd((true_samples, z))
         self.add_loss(K.mean(mmd_batch), inputs=z)
         self.add_metric(mmd_batch, aggregation="mean", name="mmd")
@@ -428,7 +433,7 @@ class Gaussian_mixture_overlap(Layer):
         dists = []
         for k in range(self.n_components):
             locs = (target[..., : self.lat_dims, k],)
-            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
+            scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k])
             dists.append(
                 tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
             )
diff --git a/deepof/models.py b/deepof/models.py
index 84c2a877f4183257d74848ea4af61dcaf8947cb3..66822d070d6c2fd9442eb79c5361d27762dd1cb8 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -507,9 +507,7 @@ class SEQ_2_SEQ_GMVAE:
         encoder = BatchNormalization()(encoder)
 
         # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
-        z_cat = Dense(self.number_of_components, activation="softmax",)(
-            encoder
-        )
+        z_cat = Dense(self.number_of_components, activation="softmax",)(encoder)
         z_cat = deepof.model_utils.Entropy_regulariser(self.entropy_reg_weight)(z_cat)
         z_gauss = Dense(
             deepof.model_utils.tfpl.IndependentNormal.params_size(
@@ -535,7 +533,7 @@ class SEQ_2_SEQ_GMVAE:
                     tfd.Independent(
                         tfd.Normal(
                             loc=gauss[1][..., : self.ENCODING, k],
-                            scale=softplus(gauss[1][..., self.ENCODING:, k]),
+                            scale=softplus(gauss[1][..., self.ENCODING :, k]),
                         ),
                         reinterpreted_batch_ndims=1,
                     )
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index befc11924163ef5b60d73efe6f0a452c2dd0a268..e1810f3d91da160b32821ded5bec0ff3d168e647 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -61,7 +61,12 @@ def load_treatments(train_path):
 
 
 def get_callbacks(
-    X_train: np.array, batch_size: int, cp: bool, variational: bool, predictor: float, loss: str,
+    X_train: np.array,
+    batch_size: int,
+    cp: bool,
+    variational: bool,
+    predictor: float,
+    loss: str,
 ) -> List[Union[Any]]:
     """Generates callbacks for model training, including:
     - run_ID: run name, with coarse parameter details;