Commit f00205fc authored by lucas_miranda

updated dependencies

parent 0c36a71a
Pipeline #87431 failed in 29 minutes and 45 seconds
@@ -9,6 +9,8 @@ Functions and general utilities for the deepof tensorflow models. See documentat
 """
 from itertools import combinations
+from typing import Any, Tuple
 from tensorflow.keras import backend as K
 from tensorflow.keras.constraints import Constraint
 from tensorflow.keras.layers import Layer
@@ -136,7 +138,7 @@ def compute_kernel(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
 @tf.function
-def compute_mmd(tensors: tuple) -> tf.Tensor:
+def compute_mmd(tensors: Tuple[Any, Any]) -> tf.Tensor:
     """
     Computes the MMD between the two specified vectors using a gaussian kernel.
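
For context on the annotation change in this hunk, a minimal usage sketch of compute_mmd applied to a two-tensor tuple; the tensor shapes and the deepof.model_utils import path are assumptions for illustration, not part of this commit:

    import tensorflow as tf
    import deepof.model_utils  # assumed location of the compute_mmd shown above

    true_samples = tf.random.normal((64, 16))  # hypothetical prior samples
    z = tf.random.normal((64, 16))             # hypothetical posterior samples
    mmd = deepof.model_utils.compute_mmd((true_samples, z))  # MMD estimate as a tf.Tensor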
@@ -317,7 +319,9 @@ class DenseTranspose(Layer):
         """Updates Layer's build method"""
         self.biases = self.add_weight(
-            name="bias", shape=[self.dense.input_shape[-1]], initializer="zeros"
+            name="bias",
+            shape=self.dense.get_input_at(-1).get_shape().as_list(),
+            initializer="zeros",
         )
         super().build(batch_input_shape)
@@ -390,6 +394,7 @@ class MMDiscrepancyLayer(Layer):
         """Updates Layer's call method"""
         true_samples = self.prior.sample(self.batch_size)
+        # noinspection PyTypeChecker
         mmd_batch = self.beta * compute_mmd((true_samples, z))
         self.add_loss(K.mean(mmd_batch), inputs=z)
         self.add_metric(mmd_batch, aggregation="mean", name="mmd")
@@ -428,7 +433,7 @@ class Gaussian_mixture_overlap(Layer):
         dists = []
         for k in range(self.n_components):
             locs = (target[..., : self.lat_dims, k],)
-            scales = tf.keras.activations.softplus(target[..., self.lat_dims :, k])
+            scales = tf.keras.activations.softplus(target[..., self.lat_dims:, k])
             dists.append(
                 tfd.BatchReshape(tfd.MultivariateNormalDiag(locs, scales), [-1])
...
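
As a side note on the loc / scale slicing convention used in the hunk above, a small self-contained sketch (shapes, variable names, and the single-component selection are hypothetical; assumes tensorflow_probability is installed):

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    lat_dims, n_components = 4, 3
    # target packs per-component means and raw scales: (batch, 2 * lat_dims, n_components)
    target = tf.random.normal((8, 2 * lat_dims, n_components))

    k = 0  # pick one mixture component
    locs = target[..., :lat_dims, k]                                   # (8, lat_dims) means
    scales = tf.keras.activations.softplus(target[..., lat_dims:, k])  # positive scales
    dist = tfd.MultivariateNormalDiag(locs, scales)                    # batch of diagonal Gaussians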
@@ -507,9 +507,7 @@ class SEQ_2_SEQ_GMVAE:
         encoder = BatchNormalization()(encoder)
         # encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
-        z_cat = Dense(self.number_of_components, activation="softmax",)(
-            encoder
-        )
+        z_cat = Dense(self.number_of_components, activation="softmax",)(encoder)
         z_cat = deepof.model_utils.Entropy_regulariser(self.entropy_reg_weight)(z_cat)
         z_gauss = Dense(
             deepof.model_utils.tfpl.IndependentNormal.params_size(
@@ -535,7 +533,7 @@ class SEQ_2_SEQ_GMVAE:
                 tfd.Independent(
                     tfd.Normal(
                         loc=gauss[1][..., : self.ENCODING, k],
-                        scale=softplus(gauss[1][..., self.ENCODING:, k]),
+                        scale=softplus(gauss[1][..., self.ENCODING :, k]),
                     ),
                     reinterpreted_batch_ndims=1,
                 )
...
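
For reference, the params_size / softplus split touched by the two hunks above follows the usual tensorflow_probability convention; a minimal sketch with made-up dimensions (not code from this commit):

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfpl = tfp.layers
    tfd = tfp.distributions

    encoding = 6
    # IndependentNormal expects one loc and one scale parameter per latent dimension
    assert tfpl.IndependentNormal.params_size(encoding) == 2 * encoding

    params = tf.random.normal((8, 2 * encoding))  # hypothetical Dense output
    posterior = tfd.Independent(
        tfd.Normal(
            loc=params[..., :encoding],
            scale=tf.keras.activations.softplus(params[..., encoding:]),
        ),
        reinterpreted_batch_ndims=1,
    )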
@@ -61,7 +61,12 @@ def load_treatments(train_path):
 def get_callbacks(
-    X_train: np.array, batch_size: int, cp: bool, variational: bool, predictor: float, loss: str,
+    X_train: np.array,
+    batch_size: int,
+    cp: bool,
+    variational: bool,
+    predictor: float,
+    loss: str,
 ) -> List[Union[Any]]:
     """Generates callbacks for model training, including:
     - run_ID: run name, with coarse parameter details;
...
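
A hypothetical call matching the reformatted signature above; the argument values, the loss string, and the deepof.train_utils import path are assumptions for illustration:

    import numpy as np
    from deepof.train_utils import get_callbacks  # assumed module path

    X_train = np.random.normal(size=(1024, 24, 28)).astype(np.float32)  # made-up training data
    callbacks = get_callbacks(
        X_train=X_train,
        batch_size=256,
        cp=True,
        variational=True,
        predictor=0.0,
        loss="ELBO+MMD",
    )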