Commit f0fd390c authored by lucas_miranda's avatar lucas_miranda
Browse files

Refactored train_model.py

parent 93ba8a91
Pipeline #83731 passed with stage
in 33 minutes and 21 seconds
......@@ -542,7 +542,7 @@ class SEQ_2_SEQ_GMVAE:
deepof.model_utils.tfd.Independent(
deepof.model_utils.tfd.Normal(
loc=gauss[1][..., : self.ENCODING, k],
scale=softplus(gauss[1][..., self.ENCODING :, k]),
scale=softplus(gauss[1][..., self.ENCODING:, k]),
),
reinterpreted_batch_ndims=1,
)
......@@ -641,14 +641,8 @@ class SEQ_2_SEQ_GMVAE:
_x_decoded_mean = TimeDistributed(Dense(input_shape[2]))(_generator)
generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")
def huber_loss(x_, x_decoded_mean_): # pragma: no cover
"""Computes huber loss with a fixed delta"""
huber = Huber(reduction="sum", delta=self.delta)
return input_shape[1:] * huber(x_, x_decoded_mean_)
gmvaep.compile(
loss=huber_loss,
loss=Huber(reduction="sum", delta=self.delta),
optimizer=Nadam(lr=self.learn_rate, clipvalue=0.5,),
metrics=["mae"],
loss_weights=([1, self.predictor] if self.predictor > 0 else [1]),
......
......@@ -13,6 +13,7 @@ from deepof.data import *
from deepof.models import *
from deepof.utils import *
from train_utils import *
from tensorboard.plugins.hparams import api as hp
from tensorflow import keras
parser = argparse.ArgumentParser(
......@@ -61,14 +62,6 @@ parser.add_argument(
type=str2bool,
default=False,
)
parser.add_argument(
"--hypermodel",
"-m",
help="Selects which hypermodel to use. It must be one of S2SAE, S2SVAE, S2SVAE-ELBO, S2SVAE-MMD, S2SVAEP, "
"S2SVAEP-ELBO and S2SVAEP-MMD. Please refer to the documentation for details on each option.",
type=str,
default="S2SVAE",
)
parser.add_argument(
"--hyperparameter-tuning",
"-tune",
......@@ -183,7 +176,6 @@ bayopt_trials = args.bayopt
exclude_bodyparts = tuple(args.exclude_bodyparts.split(","))
gaussian_filter = args.gaussian_filter
hparams = args.hyperparameters
hyp = args.hypermodel
input_type = args.input_type
k = args.components
kl_wu = args.kl_warmup
......@@ -270,17 +262,12 @@ input_dict_train = {
}
print("Preprocessing data...")
for key, value in input_dict_train.items():
input_dict_train[key] = batch_preprocess(value)
print("Done!")
print("Creating training and validation sets...")
preprocessed = batch_preprocess(input_dict_train[input_type])
# Get training and validation sets
X_train = input_dict_train[input_type][0]
X_val = input_dict_train[input_type][1]
X_train = preprocessed[0]
X_val = preprocessed[1]
print("Done!")
# Proceed with training mode. Fit autoencoder with the same parameters,
# as many times as specified by runs
if not tune:
......@@ -384,6 +371,8 @@ if not tune:
else:
# Runs hyperparameter tuning with the specified parameters and saves the results
hyp = "S2SGMVAE" if variational else "S2SAE"
run_ID, tensorboard_callback, cp_callback, onecycle = get_callbacks(
X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu
)
......@@ -420,4 +409,4 @@ else:
# TODO:
# - Investigate how goussian filters affect reproducibility (in a systematic way)
# - Investigate how smoothing affects reproducibility (in a systematic way)
# - Check if MCDropout effectively enhances reproducibility or not
# - Check if MCDropout effectively enhances reproducibility or not
\ No newline at end of file
......@@ -146,10 +146,9 @@ def tune_search(
"""
print(callbacks)
tensorboard_callback, cp_callback, onecycle = callbacks
if hypermodel == "S2SAE":
if hypermodel == "S2SAE": # pragma: no cover
hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape)
elif hypermodel == "S2SGMVAE":
......@@ -179,9 +178,9 @@ def tune_search(
tuner.search(
train,
train,
train if predictor == 0 else [train[:-1], train[1:]],
epochs=n_epochs,
validation_data=(test, test),
validation_data=(test, test if predictor == 0 else [test[:-1], test[1:]]),
verbose=1,
batch_size=256,
callbacks=[
......
......@@ -80,12 +80,12 @@ def str2bool(v: str) -> bool:
"""
if isinstance(v, bool):
return v
return v # pragma: no cover
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
else: # pragma: no cover
raise argparse.ArgumentTypeError("Boolean compatible value expected.")
......
%% Cell type:code id: tags:
``` python
import os
os.chdir(os.path.dirname("../"))
```
%% Cell type:code id: tags:
``` python
import deepof.data
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tqdm.notebook as tqdm
from ipywidgets import interact
```
%% Cell type:code id: tags:
``` python
%%time
deepof_main = deepof.data.project(path=os.path.join("..","..","Desktop","deepof-data"),
exclude_bodyparts=["Tail_tip"],
smooth_alpha=0.99,
arena_dims=[380])
```
%% Output
CPU times: user 14.1 s, sys: 2.63 s, total: 16.7 s
Wall time: 4.57 s
%% Cell type:markdown id: tags:
# Run project
%% Cell type:code id: tags:
``` python
%%time
deepof_main = deepof_main.run(verbose=True)
print(deepof_main)
```
%% Output
Loading trajectories...
Smoothing trajectories...
Computing distances...
Computing angles...
Done!
deepof analysis of 109 videos
CPU times: user 35.6 s, sys: 5.4 s, total: 41 s
Wall time: 48.5 s
%% Cell type:markdown id: tags:
# Check tagging quality
%% Cell type:code id: tags:
``` python
all_quality = pd.concat([tab for tab in deepof_main.get_quality().values()]).droplevel("scorer", axis=1)
```
%% Cell type:code id: tags:
``` python
all_quality.boxplot(rot=45)
plt.ylim(0.99985, 1.00001)
plt.show()
```
%% Output
%% Cell type:code id: tags:
``` python
@interact(quality_top=(0., 1., 0.01))
def low_quality_tags(quality_top):
pd.DataFrame(pd.melt(all_quality).groupby("bodyparts").value.apply(
lambda y: sum(y<quality_top) / len(y) * 100)
).sort_values(by="value", ascending=False).plot.bar(rot=45)
plt.xlabel("body part")
plt.ylabel("Tags with quality under {} (%)".format(quality_top))
plt.tight_layout()
plt.legend([])
plt.show()
```
%% Output
%% Cell type:markdown id: tags:
# Generate coords
%% Cell type:code id: tags:
``` python
%%time
deepof_coords = deepof_main.get_coords(center="Center", polar=False, speed=0, align="Spine_1")
```
%% Output
CPU times: user 3.22 s, sys: 484 ms, total: 3.7 s
Wall time: 4.21 s
%% Cell type:markdown id: tags:
# Visualization
%% Cell type:code id: tags:
``` python
heat = deepof_coords.plot_heatmaps(['Nose'], i=0, dpi=40)