Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Lucas Miranda
deepOF
Commits
ca53fa78
Commit
ca53fa78
authored
Sep 17, 2020
by
lucas_miranda
Browse files
Added tests for model_utils.py
parent
a0ffa747
Changes
3
Hide whitespace changes
Inline
Side-by-side
deepof/model_utils.py
View file @
ca53fa78
...
...
@@ -118,15 +118,23 @@ def compute_mmd(tensors: tuple) -> tf.Tensor:
# Custom auxiliary classes
class
OneCycleScheduler
(
tf
.
keras
.
callbacks
.
Callback
):
class
one_cycle_scheduler
(
tf
.
keras
.
callbacks
.
Callback
):
"""
One cycle learning rate scheduler.
Based on https://arxiv.org/pdf/1506.01186.pdf
"""
def
__init__
(
self
,
iterations
,
max_rate
,
start_rate
=
None
,
last_iterations
=
None
,
last_rate
=
None
,
iterations
:
int
,
max_rate
:
float
,
start_rate
:
float
=
None
,
last_iterations
:
int
=
None
,
last_rate
:
float
=
None
,
):
super
().
__init__
()
self
.
iterations
=
iterations
self
.
max_rate
=
max_rate
self
.
start_rate
=
start_rate
or
max_rate
/
10
...
...
@@ -135,10 +143,12 @@ class OneCycleScheduler(tf.keras.callbacks.Callback):
self
.
last_rate
=
last_rate
or
self
.
start_rate
/
1000
self
.
iteration
=
0
def
_interpolate
(
self
,
iter1
,
iter2
,
rate1
,
rate2
)
:
def
_interpolate
(
self
,
iter1
:
int
,
iter2
:
int
,
rate1
:
float
,
rate2
:
float
)
->
float
:
return
(
rate2
-
rate1
)
*
(
self
.
iteration
-
iter1
)
/
(
iter2
-
iter1
)
+
rate1
def
on_batch_begin
(
self
,
batch
,
logs
):
# noinspection PyMethodOverriding,PyTypeChecker
def
on_batch_begin
(
self
,
batch
:
int
,
logs
):
""" Defines computations to perform for each batch """
if
self
.
iteration
<
self
.
half_iteration
:
rate
=
self
.
_interpolate
(
0
,
self
.
half_iteration
,
self
.
start_rate
,
self
.
max_rate
...
...
examples/model_training.py
View file @
ca53fa78
...
...
@@ -438,7 +438,9 @@ for run in range(runs):
),
)
onecycle
=
OneCycleScheduler
(
X_train
.
shape
[
0
]
//
batch_size
*
250
,
max_rate
=
0.005
,)
onecycle
=
one_cycle_scheduler
(
X_train
.
shape
[
0
]
//
batch_size
*
250
,
max_rate
=
0.005
,
)
if
not
variational
:
encoder
,
decoder
,
ae
=
SEQ_2_SEQ_AE
(
X_train
.
shape
,
**
hparams
).
build
()
...
...
tests/test_model_utils.py
View file @
ca53fa78
...
...
@@ -16,6 +16,11 @@ import deepof.model_utils
import
tensorflow
as
tf
from
tensorflow.python.framework.ops
import
EagerTensor
# For coverage.py to work with @tf.function decorated functions and methods,
# graph execution is disabled when running this script with pytest
tf
.
config
.
experimental_run_functions_eagerly
(
True
)
@
settings
(
deadline
=
None
)
@
given
(
...
...
@@ -35,11 +40,11 @@ def test_far_away_uniform_initialiser(shape):
@
settings
(
deadline
=
None
)
@
given
(
tensor
=
arrays
(
shape
=
(
10
,
10
),
dtype
=
float
,
unique
=
True
,
elements
=
st
.
floats
(
min_value
=-
300
,
max_value
=
300
),
),
shape
=
(
10
,
10
),
dtype
=
float
,
unique
=
True
,
elements
=
st
.
floats
(
min_value
=-
300
,
max_value
=
300
),
),
)
def
test_compute_mmd
(
tensor
):
...
...
@@ -53,12 +58,14 @@ def test_compute_mmd(tensor):
assert
null_kernel
==
0
#
#
# @settings(deadline=None)
# @given()
# def test_onecyclescheduler():
# pass
def
test_one_cycle_scheduler
():
cycle1
=
deepof
.
model_utils
.
one_cycle_scheduler
(
iterations
=
5
,
max_rate
=
1.0
,
start_rate
=
0.1
,
last_iterations
=
2
,
last_rate
=
0.3
)
assert
type
(
cycle1
.
_interpolate
(
1
,
2
,
0.2
,
0.5
))
==
float
#
#
# @settings(deadline=None)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment