Commit ca53fa78 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added tests for model_utils.py

parent a0ffa747
......@@ -118,15 +118,23 @@ def compute_mmd(tensors: tuple) -> tf.Tensor:
# Custom auxiliary classes
class OneCycleScheduler(tf.keras.callbacks.Callback):
class one_cycle_scheduler(tf.keras.callbacks.Callback):
"""
One cycle learning rate scheduler.
Based on https://arxiv.org/pdf/1506.01186.pdf
"""
def __init__(
self,
iterations,
max_rate,
start_rate=None,
last_iterations=None,
last_rate=None,
iterations: int,
max_rate: float,
start_rate: float = None,
last_iterations: int = None,
last_rate: float = None,
):
super().__init__()
self.iterations = iterations
self.max_rate = max_rate
self.start_rate = start_rate or max_rate / 10
......@@ -135,10 +143,12 @@ class OneCycleScheduler(tf.keras.callbacks.Callback):
self.last_rate = last_rate or self.start_rate / 1000
self.iteration = 0
def _interpolate(self, iter1, iter2, rate1, rate2):
def _interpolate(self, iter1: int, iter2: int, rate1: float, rate2: float) -> float:
return (rate2 - rate1) * (self.iteration - iter1) / (iter2 - iter1) + rate1
def on_batch_begin(self, batch, logs):
# noinspection PyMethodOverriding,PyTypeChecker
def on_batch_begin(self, batch: int, logs):
""" Defines computations to perform for each batch """
if self.iteration < self.half_iteration:
rate = self._interpolate(
0, self.half_iteration, self.start_rate, self.max_rate
......
......@@ -438,7 +438,9 @@ for run in range(runs):
),
)
onecycle = OneCycleScheduler(X_train.shape[0] // batch_size * 250, max_rate=0.005,)
onecycle = one_cycle_scheduler(
X_train.shape[0] // batch_size * 250, max_rate=0.005,
)
if not variational:
encoder, decoder, ae = SEQ_2_SEQ_AE(X_train.shape, **hparams).build()
......
......@@ -16,6 +16,11 @@ import deepof.model_utils
import tensorflow as tf
from tensorflow.python.framework.ops import EagerTensor
# For coverage.py to work with @tf.function decorated functions and methods,
# graph execution is disabled when running this script with pytest
tf.config.experimental_run_functions_eagerly(True)
@settings(deadline=None)
@given(
......@@ -35,11 +40,11 @@ def test_far_away_uniform_initialiser(shape):
@settings(deadline=None)
@given(
tensor=arrays(
shape=(10, 10),
dtype=float,
unique=True,
elements=st.floats(min_value=-300, max_value=300),
),
shape=(10, 10),
dtype=float,
unique=True,
elements=st.floats(min_value=-300, max_value=300),
),
)
def test_compute_mmd(tensor):
......@@ -53,12 +58,14 @@ def test_compute_mmd(tensor):
assert null_kernel == 0
#
#
# @settings(deadline=None)
# @given()
# def test_onecyclescheduler():
# pass
def test_one_cycle_scheduler():
cycle1 = deepof.model_utils.one_cycle_scheduler(
iterations=5, max_rate=1.0, start_rate=0.1, last_iterations=2, last_rate=0.3
)
assert type(cycle1._interpolate(1, 2, 0.2, 0.5)) == float
#
#
# @settings(deadline=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment