From f00205fc7646f1eb16b1041e54874a9f78db4234 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Sun, 22 Nov 2020 01:56:59 +0100
Subject: [PATCH] updated dependencies

---
 deepof/model_utils.py | 11 ++++++++---
 deepof/models.py      |  6 ++----
 deepof/train_utils.py |  7 ++++++-
 3 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/deepof/model_utils.py b/deepof/model_utils.py
index ed26b94e..39134199 100644
--- a/deepof/model_utils.py
+++ b/deepof/model_utils.py
@@ -9,6 +9,8 @@ Functions and general utilities for the deepof tensorflow models. See documentat
 """
 
 from itertools import combinations
+from typing import Any, Tuple
+
 from tensorflow.keras import backend as K
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
@@ -136,7 +138,7 @@ def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
 
 
 @tf.function
-def compute_mmd(tensors: tuple) -> tf.Tensor:
+def compute_mmd(tensors: Tuple[Any, Any]) -> tf.Tensor:
     """
 
         Computes the MMD between the two specified vectors using a gaussian kernel.
@@ -317,7 +319,9 @@ class DenseTranspose(Layer):
         """Updates Layer's build method"""
 
         self.biases = self.add_weight(
-            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
+            name="bias",
+            shape=self.dense.get_input_at(-1).get_shape().as_list(),
+            initializer="zeros",
         )
         super().build(batch_input_shape)
 
@@ -390,6 +394,7 @@ class MMDiscrepancyLayer(Layer):
         """Updates Layer's call method"""
 
         true_samples = self.prior.sample(self.batch_size)
+        # noinspection PyTypeChecker
         mmd_batch = self.beta * compute_mmd((true_samples, z))
         self.add_loss(K.mean(mmd_batch), inputs=z)
         self.add_metric(mmd_batch, aggregation="mean", name="mmd")
@@ -428,7 +433,7 @@ class Gaussian_mixture_overlap(Layer):
         dists = []
         for k in range(self.n_components):
             locs = (target[..., : self.lat_dims, k],)
-            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
+            scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k])
 
             dists.append(
                 tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
diff --git a/deepof/models.py b/deepof/models.py
index 84c2a877..66822d07 100644
--- a/deepof/models.py
+++ b/deepof/models.py
@@ -507,9 +507,7 @@ class SEQ_2_SEQ_GMVAE:
         encoder = BatchNormalization()(encoder)
 
         # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
-        z_cat = Dense(self.number_of_components, activation="softmax",)(
-            encoder
-        )
+        z_cat = Dense(self.number_of_components, activation="softmax",)(encoder)
         z_cat = deepof.model_utils.Entropy_regulariser(self.entropy_reg_weight)(z_cat)
         z_gauss = Dense(
             deepof.model_utils.tfpl.IndependentNormal.params_size(
@@ -535,7 +533,7 @@ class SEQ_2_SEQ_GMVAE:
                     tfd.Independent(
                         tfd.Normal(
                             loc=gauss[1][..., : self.ENCODING, k],
-                            scale=softplus(gauss[1][..., self.ENCODING:, k]),
+                            scale=softplus(gauss[1][..., self.ENCODING :, k]),
                         ),
                         reinterpreted_batch_ndims=1,
                     )
diff --git a/deepof/train_utils.py b/deepof/train_utils.py
index befc1192..e1810f3d 100644
--- a/deepof/train_utils.py
+++ b/deepof/train_utils.py
@@ -61,7 +61,12 @@ def load_treatments(train_path):
 
 
 def get_callbacks(
-    X_train: np.array, batch_size: int, cp: bool, variational: bool, predictor: float, loss: str,
+    X_train: np.array,
+    batch_size: int,
+    cp: bool,
+    variational: bool,
+    predictor: float,
+    loss: str,
 ) -> List[Union[Any]]:
     """Generates callbacks for model training, including:
         - run_ID: run name, with coarse parameter details;
-- 
GitLab