Commit 7b8696ea authored by lucas_miranda's avatar lucas_miranda
Browse files

added inplace alignment on deepof.data.coordinates.get_coords() on data.py

parent ff1ec1ed
Pipeline #83460 passed with stage
in 17 minutes and 36 seconds
......@@ -12,6 +12,7 @@ from itertools import combinations
from tensorflow.keras import backend as K
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Layer
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
......@@ -20,6 +21,52 @@ tfpl = tfp.layers
# Helper functions
class exponential_learning_rate(tf.keras.callbacks.Callback):
"""Simple class that allows to grow learning rate exponentially during training"""
def __init__(self, factor):
super().__init__()
self.factor = factor
self.rates = []
self.losses = []
# noinspection PyMethodOverriding
def on_batch_end(self, batch, logs):
"""This callback acts after processing each batch"""
self.rates.append(K.get_value(self.model.optimizer.lr))
self.losses.append(logs["loss"])
K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)
def find_learning_rate(
model, X, y, epochs=1, batch_size=32, min_rate=10 ** -5, max_rate=10
):
"""Trains the provided model for an epoch with an exponentially increasing learning rate"""
init_weights = model.get_weights()
iterations = len(X) // batch_size * epochs
factor = K.exp(K.log(max_rate / min_rate) / iterations)
init_lr = K.get_value(model.optimizer.lr)
K.set_value(model.optimizer.lr, min_rate)
exp_lr = exponential_learning_rate(factor)
model.fit(X, y, epochs=epochs, batch_size=batch_size, callbacks=[exp_lr])
K.set_value(model.optimizer.lr, init_lr)
model.set_weights(init_weights)
return exp_lr.rates, exp_lr.losses
def plot_lr_vs_loss(rates, losses): # pragma: no cover
"""Plots learing rate versus the loss function of the model"""
plt.plot(rates, losses)
plt.gca().set_xscale("log")
plt.hlines(min(losses), min(rates), max(rates))
plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])
plt.xlabel("Learning rate")
plt.ylabel("Loss")
@tf.function
def far_away_uniform_initialiser(
shape: tuple, minval: int = 0, maxval: int = 15, iters: int = 100000
......
......@@ -227,6 +227,8 @@ class SEQ_2_SEQ_AE:
metrics=["mae"],
)
model.build(input_shape)
return encoder, decoder, model
......@@ -652,6 +654,8 @@ class SEQ_2_SEQ_GMVAE:
loss_weights=([1, self.predictor] if self.predictor > 0 else [1]),
)
gmvaep.build(input_shape)
return (
encoder,
generator,
......
This diff is collapsed.
......@@ -49,7 +49,7 @@ parser.add_argument(
"-v",
help="Sets the model to train to a variational Bayesian autoencoder. Defaults to True",
default=True,
type=deepof.preprocess.str2bool,
type=deepof.data.str2bool,
)
parser.add_argument(
"--loss",
......
......@@ -201,6 +201,7 @@ def test_get_table_dicts(nodes, ego, sampler):
prun = prun.run(verbose=False)
algn = sampler.draw(st.one_of(st.just(False), st.just("Nose")))
inplace = sampler.draw(st.booleans())
polar = st.one_of(st.just(True), st.just(False))
speed = sampler.draw(st.integers(min_value=0, max_value=5))
......@@ -209,6 +210,7 @@ def test_get_table_dicts(nodes, ego, sampler):
polar=polar,
length=sampler.draw(st.one_of(st.just(False), st.just("00:10:00"))),
align=algn,
align_inplace=inplace,
)
speeds = prun.get_coords(
center=sampler.draw(st.one_of(st.just("arena"), st.just("Center"))),
......
......@@ -244,3 +244,19 @@ def test_entropy_regulariser():
fit = test_model.fit(X, y, epochs=10, batch_size=100)
assert type(fit) == tf.python.keras.callbacks.History
def test_find_learning_rate():
X = np.random.uniform(0, 10, [1500, 5])
y = np.random.randint(0, 2, [1500, 1])
test_model = tf.keras.Sequential()
test_model.add(tf.keras.layers.Dense(1))
test_model.add(deepof.model_utils.Entropy_regulariser(1.0))
test_model.compile(
loss=tf.keras.losses.binary_crossentropy, optimizer=tf.keras.optimizers.SGD(),
)
test_model.build(X.shape)
deepof.model_utils.find_learning_rate(test_model, X, y)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment