diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index ed26b94eb7f84cad870b7066f62a64ff9e38bdd2..391341995221234af8369af32defb9ee87631e1a 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -9,6 +9,8 @@ Functions and general utilities for the deepof tensorflow models. See documentat
 """
 
 from itertools import combinations
+from typing import Any, Tuple
+
 from tensorflow.keras import backend as K
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
@@ -136,7 +138,7 @@ def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
 
 
 @tf.function
-def compute_mmd(tensors: tuple) -> tf.Tensor:
+def compute_mmd(tensors: Tuple[Any, Any]) -> tf.Tensor:
     """
 
         Computes the MMD between the two specified vectors using a gaussian kernel.
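The tightened annotation reflects that `compute_mmd` receives a single positional tuple of two tensors (as in the `MMDiscrepancyLayer` call further down), not two separate arguments. A minimal usage sketch, with made-up shapes:

```python
import tensorflow as tf
from deepof.model_utils import compute_mmd

prior_samples = tf.random.normal([64, 16])  # e.g. samples drawn from the prior
encodings = tf.random.normal([64, 16])      # e.g. an encoded minibatch
mmd = compute_mmd((prior_samples, encodings))  # note: one tuple argument
```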
@@ -317,7 +319,9 @@ class DenseTranspose(Layer):
         """Updates Layer's build method"""
 
         self.biases = self.add_weight(
-            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
+            name="bias",
+            shape=self.dense.get_input_at(-1).get_shape().as_list()[1:],
+            initializer="zeros",
         )
         super().build(batch_input_shape)
 
@@ -390,6 +394,7 @@ class MMDiscrepancyLayer(Layer):
         """Updates Layer's call method"""
 
         true_samples = self.prior.sample(self.batch_size)
+        # noinspection PyTypeChecker
         mmd_batch = self.beta * compute_mmd((true_samples, z))
         self.add_loss(K.mean(mmd_batch), inputs=z)
         self.add_metric(mmd_batch, aggregation="mean", name="mmd")
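The `# noinspection PyTypeChecker` directive only silences PyCharm, presumably because it cannot infer the tuple parameter through the `tf.function` wrapper around `compute_mmd`. For readers without the full class in view, a stripped-down sketch of the activity-regularizer pattern this hunk sits in (an illustrative stand-in, not the real `MMDiscrepancyLayer`, which carries more configuration):

```python
import tensorflow as tf
from deepof.model_utils import compute_mmd

class MMDSketch(tf.keras.layers.Layer):
    """Illustrative stand-in for MMDiscrepancyLayer, not the deepof class."""

    def __init__(self, prior, beta=1.0, batch_size=64, **kwargs):
        super().__init__(**kwargs)
        self.prior, self.beta, self.batch_size = prior, beta, batch_size

    def call(self, z):
        true_samples = self.prior.sample(self.batch_size)
        mmd_batch = self.beta * compute_mmd((true_samples, z))
        self.add_loss(tf.reduce_mean(mmd_batch), inputs=z)
        self.add_metric(mmd_batch, aggregation="mean", name="mmd")
        return z  # the encoding passes through unchanged
```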
@@ -428,7 +433,7 @@ class Gaussian_mixture_overlap(Layer):
         dists = []
         for k in range(self.n_components):
             locs = (target[..., : self.lat_dims, k],)
-            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
+            scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k])
 
             dists.append(
                 tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
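Both this hunk and the `models.py` change below touch the same idiom: unconstrained network outputs are passed through softplus to yield strictly positive scale parameters for the Gaussian components. A quick numeric check (values rounded):

```python
import tensorflow as tf

raw = tf.constant([-2.0, 0.0, 3.0])
tf.keras.activations.softplus(raw)  # -> ~[0.13, 0.69, 3.05]: always > 0, near-identity for large inputs
```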
diff --git a/deepof/models.py b/deepof/models.py
index 84c2a877f4183257d74848ea4af61dcaf8947cb3..66822d070d6c2fd9442eb79c5361d27762dd1cb8 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -507,9 +507,7 @@ class SEQ_2_SEQ_GMVAE:
         encoder = BatchNormalization()(encoder)
 
         # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
-        z_cat = Dense(self.number_of_components, activation="softmax",)(
-            encoder
-        )
+        z_cat = Dense(self.number_of_components, activation="softmax")(encoder)
         z_cat = deepof.model_utils.Entropy_regulariser(self.entropy_reg_weight)(z_cat)
         z_gauss = Dense(
             deepof.model_utils.tfpl.IndependentNormal.params_size(
@@ -535,7 +533,7 @@ class SEQ_2_SEQ_GMVAE:
                     tfd.Independent(
                         tfd.Normal(
                             loc=gauss[1][..., : self.ENCODING, k],
-                            scale=softplus(gauss[1][..., self.ENCODING:, k]),
+                            scale=softplus(gauss[1][..., self.ENCODING :, k]),
                         ),
                         reinterpreted_batch_ndims=1,
                     )
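The slicing in this hunk implies that `gauss[1]` stacks means in the first `ENCODING` channels and raw scales in the remaining ones, per mixture component. A self-contained sketch of assembling one component under that assumption (all sizes illustrative):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

ENCODING, COMPONENTS = 6, 4                               # made-up sizes
gauss = tf.random.normal([32, 2 * ENCODING, COMPONENTS])  # stand-in for gauss[1]
k = 0                                                     # component index

component_k = tfd.Independent(
    tfd.Normal(
        loc=gauss[..., :ENCODING, k],                      # first half: means
        scale=tf.math.softplus(gauss[..., ENCODING:, k]),  # second half: raw scales
    ),
    reinterpreted_batch_ndims=1,
)
```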
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index befc11924163ef5b60d73efe6f0a452c2dd0a268..e1810f3d91da160b32821ded5bec0ff3d168e647 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -61,7 +61,12 @@ def load_treatments(train_path):
 
 
 def get_callbacks(
-    X_train: np.array, batch_size: int, cp: bool, variational: bool, predictor: float, loss: str,
+    X_train: np.array,
+    batch_size: int,
+    cp: bool,
+    variational: bool,
+    predictor: float,
+    loss: str,
 ) -> List[Union[Any]]:
     """Generates callbacks for model training, including:
         - run_ID: run name, with coarse parameter details;
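The signature reflow is purely cosmetic. For reference, a call matching it; the argument values are illustrative, `cp` is taken to toggle checkpointing, and the exact contents of the returned list depend on the deepof version:

```python
import numpy as np
from deepof.train_utils import get_callbacks

X_train = np.random.normal(size=(1000, 24, 28)).astype(np.float32)  # made-up shape

callbacks = get_callbacks(
    X_train=X_train,
    batch_size=256,
    cp=True,          # assumed: add a model-checkpoint callback
    variational=True,
    predictor=0.0,
    loss="ELBO",      # illustrative loss label
)
```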