Commit b51981dd authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored train_utils.py

parent eb008013
......@@ -139,7 +139,7 @@ def get_callbacks(
return callbacks
def log_hyperparameters():
def log_hyperparameters(phenotype_class):
"""Blueprint for hyperparameter and metric logging in tensorboard during hyperparameter tuning"""
logparams = [
......@@ -288,7 +288,7 @@ def autoencoder_fitting(
# Logs hyperparameters to tensorboard
if log_hparams:
logparams, metrics = log_hyperparameters()
logparams, metrics = log_hyperparameters(phenotype_class)
with tf.summary.create_file_writer(
os.path.join(output_path, "hparams", run_ID)
......
......@@ -9,7 +9,7 @@ Testing module for deepof.train_utils
"""
from hypothesis import given
from hypothesis import settings
from hypothesis import settings, reproduce_failure
from hypothesis import strategies as st
from hypothesis.extra.numpy import arrays
import deepof.data
......@@ -127,10 +127,11 @@ def test_autoencoder_fitting(
batch_size=batch_size,
encoding_size=encoding_size,
epochs=1,
log_history=True,
log_hparams=True,
n_components=k,
loss=loss,
phenotype_class=pheno_class,
phenotype_class=0,
predictor=predictor,
variational=variational,
)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment