Commit f00205fc authored by lucas_miranda's avatar lucas_miranda
Browse files

updated dependencies

parent 0c36a71a
Pipeline #87431 failed with stage
in 29 minutes and 45 seconds
......@@ -9,6 +9,8 @@ Functions and general utilities for the deepof tensorflow models. See documentat
"""
from itertools import combinations
from typing import Any, Tuple
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
......@@ -136,7 +138,7 @@ def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
@tf.function
def compute_mmd(tensors: tuple) -> tf.Tensor:
def compute_mmd(tensors: Tuple[Any, Any]) -> tf.Tensor:
"""
Computes the MMD between the two specified vectors using a gaussian kernel.
......@@ -317,7 +319,9 @@ class DenseTranspose(Layer):
"""Updates Layer's build method"""
self.biases = self.add_weight(
name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
name="bias",
shape=self.dense.get_input_at(-1).get_shape().as_list(),
initializer="zeros",
)
super().build(batch_input_shape)
......@@ -390,6 +394,7 @@ class MMDiscrepancyLayer(Layer):
"""Updates Layer's call method"""
true_samples = self.prior.sample(self.batch_size)
# noinspection PyTypeChecker
mmd_batch = self.beta * compute_mmd((true_samples, z))
self.add_loss(K.mean(mmd_batch), inputs=z)
self.add_metric(mmd_batch, aggregation="mean", name="mmd")
......@@ -428,7 +433,7 @@ class Gaussian_mixture_overlap(Layer):
dists = []
for k in range(self.n_components):
locs = (target[..., : self.lat_dims, k],)
scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k])
dists.append(
tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
......
......@@ -507,9 +507,7 @@ class SEQ_2_SEQ_GMVAE:
encoder = BatchNormalization()(encoder)
# encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
z_cat = Dense(self.number_of_components, activation="softmax",)(
encoder
)
z_cat = Dense(self.number_of_components, activation="softmax",)(encoder)
z_cat = deepof.model_utils.Entropy_regulariser(self.entropy_reg_weight)(z_cat)
z_gauss = Dense(
deepof.model_utils.tfpl.IndependentNormal.params_size(
......@@ -535,7 +533,7 @@ class SEQ_2_SEQ_GMVAE:
tfd.Independent(
tfd.Normal(
loc=gauss[1][..., : self.ENCODING, k],
scale=softplus(gauss[1][..., self.ENCODING:, k]),
scale=softplus(gauss[1][..., self.ENCODING :, k]),
),
reinterpreted_batch_ndims=1,
)
......
......@@ -61,7 +61,12 @@ def load_treatments(train_path):
def get_callbacks(
X_train: np.array, batch_size: int, cp: bool, variational: bool, predictor: float, loss: str,
X_train: np.array,
batch_size: int,
cp: bool,
variational: bool,
predictor: float,
loss: str,
) -> List[Union[Any]]:
"""Generates callbacks for model training, including:
- run_ID: run name, with coarse parameter details;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment