Commit f0fd390c authored by Lucas Miranda's avatar Lucas Miranda

Refactored train_model.py

parent 93ba8a91
Pipeline #83731 passed
@@ -641,14 +641,8 @@ class SEQ_2_SEQ_GMVAE:
         _x_decoded_mean = TimeDistributed(Dense(input_shape[2]))(_generator)
         generator = Model(g, _x_decoded_mean, name="SEQ_2_SEQ_VGenerator")

-        def huber_loss(x_, x_decoded_mean_):  # pragma: no cover
-            """Computes huber loss with a fixed delta"""
-            huber = Huber(reduction="sum", delta=self.delta)
-            return input_shape[1:] * huber(x_, x_decoded_mean_)
-
         gmvaep.compile(
-            loss=huber_loss,
+            loss=Huber(reduction="sum", delta=self.delta),
             optimizer=Nadam(lr=self.learn_rate, clipvalue=0.5,),
             metrics=["mae"],
             loss_weights=([1, self.predictor] if self.predictor > 0 else [1]),
...
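Note on the loss swap above: the deleted closure scaled the summed Huber loss by the trailing input dimensions before returning it, so compiling with a bare `Huber(reduction="sum", delta=self.delta)` also rescales the reconstruction term relative to the other losses rather than being numerically identical. A minimal sketch of the two behaviours, with shapes and `delta` as illustrative assumptions:

``` python
import tensorflow as tf
from tensorflow.keras.losses import Huber

x = tf.random.normal((8, 13, 24))       # (batch, window, features) -- assumed shapes
x_hat = tf.random.normal((8, 13, 24))

huber = Huber(reduction="sum", delta=1.0)  # delta=1.0 is illustrative; the model uses self.delta

new_loss = huber(x, x_hat)  # what the commit compiles with
# Roughly what the removed closure returned: the same summed loss scaled
# elementwise by input_shape[1:] == (13, 24).
old_loss = tf.constant([13.0, 24.0]) * huber(x, x_hat)

print(float(new_loss), old_loss.numpy())  # same quantity up to constant factors
```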
@@ -13,6 +13,7 @@ from deepof.data import *
 from deepof.models import *
 from deepof.utils import *
 from train_utils import *
+from tensorboard.plugins.hparams import api as hp
 from tensorflow import keras

 parser = argparse.ArgumentParser(
@@ -61,14 +62,6 @@ parser.add_argument(
     type=str2bool,
     default=False,
 )
-parser.add_argument(
-    "--hypermodel",
-    "-m",
-    help="Selects which hypermodel to use. It must be one of S2SAE, S2SVAE, S2SVAE-ELBO, S2SVAE-MMD, S2SVAEP, "
-    "S2SVAEP-ELBO and S2SVAEP-MMD. Please refer to the documentation for details on each option.",
-    type=str,
-    default="S2SVAE",
-)
 parser.add_argument(
     "--hyperparameter-tuning",
     "-tune",
@@ -183,7 +176,6 @@ bayopt_trials = args.bayopt
 exclude_bodyparts = tuple(args.exclude_bodyparts.split(","))
 gaussian_filter = args.gaussian_filter
 hparams = args.hyperparameters
-hyp = args.hypermodel
 input_type = args.input_type
 k = args.components
 kl_wu = args.kl_warmup
@@ -270,17 +262,12 @@ input_dict_train = {
 }

 print("Preprocessing data...")
-for key, value in input_dict_train.items():
-    input_dict_train[key] = batch_preprocess(value)
-print("Done!")
-
-print("Creating training and validation sets...")
+preprocessed = batch_preprocess(input_dict_train[input_type])

 # Get training and validation sets
-X_train = input_dict_train[input_type][0]
-X_val = input_dict_train[input_type][1]
+X_train = preprocessed[0]
+X_val = preprocessed[1]
 print("Done!")

 # Proceed with training mode. Fit autoencoder with the same parameters,
 # as many times as specified by runs
 if not tune:
@@ -384,6 +371,8 @@ if not tune:
 else:
     # Runs hyperparameter tuning with the specified parameters and saves the results
+    hyp = "S2SGMVAE" if variational else "S2SAE"
+
     run_ID, tensorboard_callback, cp_callback, onecycle = get_callbacks(
         X_train, batch_size, variational, predictor, k, loss, kl_wu, mmd_wu
     )
...
@@ -146,10 +146,9 @@ def tune_search(
     """

-    print(callbacks)
     tensorboard_callback, cp_callback, onecycle = callbacks

-    if hypermodel == "S2SAE":
+    if hypermodel == "S2SAE":  # pragma: no cover
         hypermodel = deepof.hypermodels.SEQ_2_SEQ_AE(input_shape=train.shape)

     elif hypermodel == "S2SGMVAE":
@@ -179,9 +178,9 @@ def tune_search(
     tuner.search(
         train,
-        train,
+        train if predictor == 0 else [train[:-1], train[1:]],
         epochs=n_epochs,
-        validation_data=(test, test),
+        validation_data=(test, test if predictor == 0 else [test[:-1], test[1:]]),
         verbose=1,
         batch_size=256,
         callbacks=[
...
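The retouched `tuner.search` call above changes the target layout when the predictor branch is active; a small sketch of that layout, with array shapes as assumptions:

``` python
import numpy as np

train = np.random.normal(size=(100, 13, 24))  # (windows, window_size, features) -- assumed

predictor = 1  # a weight > 0 enables the next-window prediction head
y = train if predictor == 0 else [train[:-1], train[1:]]
# With predictor > 0, y[0] pairs each window with itself (reconstruction) and
# y[1] with the window that follows it (one-step-ahead prediction).
```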
@@ -80,12 +80,12 @@ def str2bool(v: str) -> bool:
     """

     if isinstance(v, bool):
-        return v
+        return v  # pragma: no cover
     if v.lower() in ("yes", "true", "t", "y", "1"):
         return True
     elif v.lower() in ("no", "false", "f", "n", "0"):
         return False
-    else:
+    else:  # pragma: no cover
         raise argparse.ArgumentTypeError("Boolean compatible value expected.")
...
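For reference, `str2bool` is used as an argparse `type`, so flag values written as words or digits all normalise to booleans; a small usage sketch (the flag name is illustrative):

``` python
import argparse
from train_utils import str2bool

parser = argparse.ArgumentParser()
parser.add_argument("--variational", type=str2bool, default=True)  # example flag

print(parser.parse_args(["--variational", "yes"]).variational)  # True
print(parser.parse_args(["--variational", "0"]).variational)    # False
```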
%% Cell type:code id: tags:

``` python
import os
os.chdir(os.path.dirname("../"))
```

%% Cell type:code id: tags:

``` python
import deepof.data
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tqdm.notebook as tqdm
from ipywidgets import interact
```

%% Cell type:code id: tags:

``` python
%%time
deepof_main = deepof.data.project(path=os.path.join("..", "..", "Desktop", "deepof-data"),
                                  exclude_bodyparts=["Tail_tip"],
                                  smooth_alpha=0.99,
                                  arena_dims=[380])
```

%% Output

CPU times: user 14.1 s, sys: 2.63 s, total: 16.7 s
Wall time: 4.57 s
%% Cell type:markdown id: tags:

# Run project

%% Cell type:code id: tags:

``` python
%%time
deepof_main = deepof_main.run(verbose=True)
print(deepof_main)
```

%% Output

Loading trajectories...
Smoothing trajectories...
Computing distances...
Computing angles...
Done!
deepof analysis of 109 videos
CPU times: user 35.6 s, sys: 5.4 s, total: 41 s
Wall time: 48.5 s
%% Cell type:markdown id: tags:

# Check tagging quality

%% Cell type:code id: tags:

``` python
all_quality = pd.concat([tab for tab in deepof_main.get_quality().values()]).droplevel("scorer", axis=1)
```

%% Cell type:code id: tags:

``` python
all_quality.boxplot(rot=45)
plt.ylim(0.99985, 1.00001)
plt.show()
```

%% Output

%% Cell type:code id: tags:

``` python
@interact(quality_top=(0., 1., 0.01))
def low_quality_tags(quality_top):
    pd.DataFrame(pd.melt(all_quality).groupby("bodyparts").value.apply(
        lambda y: sum(y < quality_top) / len(y) * 100)
    ).sort_values(by="value", ascending=False).plot.bar(rot=45)
    plt.xlabel("body part")
    plt.ylabel("Tags with quality under {} (%)".format(quality_top))
    plt.tight_layout()
    plt.legend([])
    plt.show()
```

%% Output
%% Cell type:markdown id: tags:

# Generate coords

%% Cell type:code id: tags:

``` python
%%time
deepof_coords = deepof_main.get_coords(center="Center", polar=False, speed=0, align="Spine_1")
```

%% Output

CPU times: user 3.22 s, sys: 484 ms, total: 3.7 s
Wall time: 4.21 s

%% Cell type:markdown id: tags:

# Visualization

%% Cell type:code id: tags:

``` python
heat = deepof_coords.plot_heatmaps(['Nose'], i=0, dpi=40)
```

%% Output

%% Cell type:code id: tags:

``` python
deepof_coords = deepof_main.get_coords(center="Center", polar=False, speed=0, align="Spine_1", align_inplace=True)
```
%% Cell type:code id: tags:

``` python
# Draft: function to produce a video with the animal in motion using cv2
import cv2

w = 400
h = 400
factor = 2.5

# Instantiate video
writer = cv2.VideoWriter()
writer.open("test_video.avi", cv2.VideoWriter_fourcc(*"MJPG"),
            24, (int(w*factor), int(h*factor)), True)

for frame in tqdm.tqdm(range(100)):

    image = np.zeros((h, w, 3), np.uint8) + 30

    for bpart in deepof_coords["Test 10_s2"].columns.levels[0]:
        try:
            pos = ((- int(deepof_coords["Test 10_s2"][bpart].loc[frame, "x"]) + w//2),
                   (- int(deepof_coords["Test 10_s2"][bpart].loc[frame, "y"]) + h//2))
            cv2.circle(image, pos, 2, (0, 0, 255), -1)
        except KeyError:
            continue

    # draw skeleton
    def draw_line(start, end):
        for bpart in end:
            cv2.line(image, tuple(- deepof_coords["Test 10_s2"][start].loc[frame, :].astype(int) + w//2),
                     tuple(- deepof_coords["Test 10_s2"][bpart].loc[frame, :].astype(int) + h//2), (0, 0, 255), 1)

    draw_line("Nose", ["Left_ear", "Right_ear"])
    draw_line("Spine_1", ["Left_ear", "Right_ear", "Left_fhip", "Right_fhip"])
    draw_line("Spine_2", ["Spine_1", "Tail_base", "Left_bhip", "Right_bhip"])
    #draw_line("Tail_1", ["Tail_base", "Tail_2"])
    #draw_line("Tail_tip", ["Tail_2"])

    image = cv2.resize(image, (0, 0), fx=factor, fy=factor)
    writer.write(image)

writer.release()
cv2.destroyAllWindows()
```

%% Output
%% Cell type:markdown id: tags:

# Preprocessing

%% Cell type:code id: tags:

``` python
%%time
deepof_train, deepof_test = deepof_coords.preprocess(window_size=13, window_step=10, conv_filter=None, sigma=55,
                                                     shift=0, scale='standard', align='all', shuffle=True, test_videos=20)
print("Train dataset shape: ", deepof_train.shape)
print("Test dataset shape: ", deepof_test.shape)
```

%% Output

Train dataset shape:  (133517, 13, 24)
Test dataset shape:  (30003, 13, 24)
CPU times: user 44.8 s, sys: 1.36 s, total: 46.2 s
Wall time: 46.7 s

The commit also removes the body-part scatter cell that used to follow:

-``` python
-n = 100
-plt.scatter(deepof_train[:n,10,0], deepof_train[:n,10,1], label='Nose')
-plt.scatter(deepof_train[:n,10,2], deepof_train[:n,10,3], label='Right ear')
-plt.scatter(deepof_train[:n,10,4], deepof_train[:n,10,5], label='Right hips')
-plt.scatter(deepof_train[:n,10,6], deepof_train[:n,10,7], label='Left ear')
-plt.scatter(deepof_train[:n,10,8], deepof_train[:n,10,9], label='Left hips')
-plt.scatter(deepof_train[:n,10,10], deepof_train[:n,10,11], label='Tail base')
-plt.xlabel('x')
-plt.ylabel('y')
-plt.legend()
-plt.show()
-```
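%% Cell type:markdown id: tags:

For reference, the sliding-window layout behind the shapes printed above; a standalone sketch of the windowing, not deepof's implementation (shuffling and the train/test split are omitted):

%% Cell type:code id: tags:

``` python
# Each table of T frames x F features becomes overlapping windows of shape
# (window_size, F), advanced by window_step frames.
T, F, window_size, window_step = 100, 24, 13, 10
tab = np.random.normal(size=(T, F))
windows = np.stack([tab[i:i + window_size]
                    for i in range(0, T - window_size + 1, window_step)])
print(windows.shape)  # (9, 13, 24)
```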
%% Cell type:markdown id: tags:

# Build models and get learning rate (1-cycle policy)

%% Cell type:markdown id: tags:

### Seq 2 seq Variational Auto Encoder

%% Cell type:code id: tags:

``` python
from datetime import datetime
import tensorflow.keras as k
import tensorflow as tf
```

%% Cell type:code id: tags:

``` python
-NAME = 'Baseline_AE_512_wu10_slide10_gauss_fullval'
+NAME = 'Baseline_AE'
log_dir = os.path.abspath(
    "logs/fit/{}_{}".format(NAME, datetime.now().strftime("%Y%m%d-%H%M%S"))
)
tensorboard_callback = k.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
```

%% Cell type:code id: tags:

``` python
from deepof.models import SEQ_2_SEQ_AE, SEQ_2_SEQ_GMVAE
```

%% Cell type:code id: tags:

``` python
encoder, decoder, ae = SEQ_2_SEQ_AE().build(deepof_train.shape)
```

%% Cell type:code id: tags:

``` python
-ae.summary()
+decoder.summary()
```
%% Output
Model: "SEQ_2_SEQ_Decoder"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_transpose (DenseTransp multiple 1104
_________________________________________________________________
batch_normalization_5 (Batch multiple 256
_________________________________________________________________
dense_transpose_1 (DenseTran multiple 8384
_________________________________________________________________
batch_normalization_6 (Batch multiple 512
_________________________________________________________________
dense_transpose_2 (DenseTran multiple 33152
_________________________________________________________________
batch_normalization_7 (Batch multiple 1024
_________________________________________________________________
repeat_vector (RepeatVector) multiple 0
_________________________________________________________________
bidirectional_2 (Bidirection multiple 1050624
_________________________________________________________________
batch_normalization_8 (Batch multiple 2048
_________________________________________________________________
bidirectional_3 (Bidirection multiple 1574912
_________________________________________________________________
time_distributed (TimeDistri multiple 12312
=================================================================
Total params: 2,684,328
Trainable params: 2,682,408
Non-trainable params: 1,920
_________________________________________________________________
%% Cell type:code id: tags:

``` python
%%time
tf.keras.backend.clear_session()
encoder, generator, grouper, gmvaep, kl_warmup_callback, mmd_warmup_callback = SEQ_2_SEQ_GMVAE(loss='ELBO',
                                                                                               number_of_components=5,
                                                                                               kl_warmup_epochs=10,
                                                                                               mmd_warmup_epochs=10,
                                                                                               predictor=False).build(deepof_train.shape)
-gmvaep.summary()
+generator.summary()
```
%% Output
Model: "SEQ_2_SEQ_VGenerator"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) [(None, 16)] 0
_________________________________________________________________
dense_2 (Dense) (None, 64) 1024
_________________________________________________________________
batch_normalization (BatchNo (None, 64) 256
_________________________________________________________________
dense_3 (Dense) (None, 128) 8192
_________________________________________________________________
batch_normalization_1 (Batch (None, 128) 512
_________________________________________________________________
repeat_vector (RepeatVector) (None, 13, 128) 0
_________________________________________________________________
bidirectional_2 (Bidirection (None, 13, 256) 262144
_________________________________________________________________
batch_normalization_2 (Batch (None, 13, 256) 1024
_________________________________________________________________
bidirectional_3 (Bidirection (None, 13, 512) 1048576
_________________________________________________________________
batch_normalization_3 (Batch (None, 13, 512) 2048
_________________________________________________________________
time_distributed (TimeDistri (None, 13, 24) 12312
=================================================================
Total params: 1,336,088
Trainable params: 1,334,168
Non-trainable params: 1,920
_________________________________________________________________
CPU times: user 4.11 s, sys: 102 ms, total: 4.21 s
Wall time: 4.17 s
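%% Cell type:markdown id: tags:

The build call above also returns the two warm-up callbacks; a sketch of how they would be passed to `fit` (epochs and batch size are illustrative assumptions):

%% Cell type:code id: tags:

``` python
# Hypothetical training call: the callbacks anneal the KL and MMD loss weights
# over their respective warm-up epochs.
history = gmvaep.fit(deepof_train, deepof_train,
                     epochs=50, batch_size=512,
                     validation_data=(deepof_test, deepof_test),
                     callbacks=[tensorboard_callback, kl_warmup_callback, mmd_warmup_callback])
```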
%% Cell type:code id: tags:

``` python
batch_size = 512
rates, losses = deepof.model_utils.find_learning_rate(gmvaep, deepof_train[:512*10], deepof_test[:512*10], epochs=1, batch_size=batch_size)
deepof.model_utils.plot_lr_vs_loss(rates, losses)
plt.title("Learning rate tuning")
plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 1.4])
plt.show()
```
%% Cell type:markdown id: tags:

# Encoding plots

%% Cell type:code id: tags:

``` python
import umap
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
import plotly.express as px
```

%% Cell type:code id: tags:

``` python
data = pttest
samples = 15000
montecarlo = 10
```

%% Cell type:code id: tags:

``` python
weights = "GMVAE_components=30_loss=ELBO_kl_warmup=30_mmd_warmup=30_20200804-225526_final_weights.h5"
gmvaep.load_weights(weights)

if montecarlo:
    clusts = np.stack([grouper(data[:samples]) for sample in (tqdm(range(montecarlo)))])
    clusters = clusts.mean(axis=0)
    clusters = np.argmax(clusters, axis=1)
else:
    clusters = grouper(data[:samples], training=False)
    clusters = np.argmax(clusters, axis=1)
```
%% Cell type:code id: tags:

``` python
def plot_encodings(data, samples, n, clusters, threshold):

    reducer = PCA(n_components=n)
    clusters = clusters[:, :samples]
    filter = np.max(np.mean(clusters, axis=0), axis=1) > threshold
    encoder.predict(data[:samples][filter])

    print("{}/{} samples used ({}%); confidence threshold={}".format(sum(filter),
                                                                     samples,
                                                                     sum(filter)/samples*100,
                                                                     threshold))

    clusters = np.argmax(np.mean(clusters, axis=0), axis=1)[filter]
    rep = reducer.fit_transform(encoder.predict(data[:samples][filter]))

    if n == 2:
        df = pd.DataFrame({"encoding-1": rep[:,0], "encoding-2": rep[:,1], "clusters": ["A"+str(i) for i in clusters]})
        enc = px.scatter(data_frame=df, x="encoding-1", y="encoding-2",
                         color="clusters", width=600, height=600,
                         color_discrete_sequence=px.colors.qualitative.T10)
    elif n == 3:
        df3d = pd.DataFrame({"encoding-1": rep[:,0], "encoding-2": rep[:,1], "encoding-3": rep[:,2],
                             "clusters": ["A"+str(i) for i in clusters]})
        enc = px.scatter_3d(data_frame=df3d, x="encoding-1", y="encoding-2", z="encoding-3",
                            color="clusters", width=600, height=600,
                            color_discrete_sequence=px.colors.qualitative.T10)

    return enc

plot_encodings(data, 5000, 2, clusts, 0.5)
```
%% Cell type:markdown id: tags:

# Confidence per cluster

%% Cell type:code id: tags:

``` python
from collections import Counter
Counter(clusters)
```

%% Cell type:code id: tags:

``` python
# Confidence distribution per cluster
for cl in range(5):
    cl_select = np.argmax(np.mean(clusts, axis=0), axis=1) == cl
    dt = np.mean(clusts[:, cl_select, cl], axis=0)
    sns.kdeplot(dt, shade=True, label=cl)
plt.xlabel('MC Dropout confidence')
plt.ylabel('Density')
plt.show()
```
%% Cell type:code id: tags:

``` python
def animated_cluster_heatmap(data, clust, clusters, threshold=0.75, samples=False):

    if not samples:
        samples = data.shape[0]

    tpoints = data.shape[1]
    bdparts = data.shape[2] // 2

    cls = clusters[:, :samples, :]
    filt = np.max(np.mean(cls, axis=0), axis=1) > threshold
    cls = np.argmax(np.mean(cls, axis=0), axis=1)[filt]

    clust_series = data[:samples][filt][cls == clust]
    rshape = clust_series.reshape(clust_series.shape[0]*clust_series.shape[1],
                                  clust_series.shape[2])

    cluster_df = pd.DataFrame()
    cluster_df['x'] = rshape[:, [0,2,4,6,8,10]].flatten(order='F')
    cluster_df['y'] = rshape[:, [1,3,5,7,9,11]].flatten(order='F')
    cluster_df['bpart'] = np.tile(np.repeat(np.arange(bdparts),
                                            clust_series.shape[0]), tpoints)
    cluster_df['frame'] = np.tile(np.repeat(np.arange(tpoints),
                                            clust_series.shape[0]), bdparts)

    fig = px.density_contour(data_frame=cluster_df, x='x', y='y', animation_frame='frame',
                             width=600, height=600,
                             color='bpart', color_discrete_sequence=px.colors.qualitative.T10)

    fig.update_traces(contours_coloring="fill",
                      contours_showlabels=True)

    fig.update_xaxes(range=[-3, 3])
    fig.update_yaxes(range=[-3, 3])

    return fig
```

%% Cell type:code id: tags:

``` python
# animated_cluster_heatmap(pttest, 4, clusts, samples=10)
```
%% Cell type:markdown id: tags:

# Stability across runs

%% Cell type:code id: tags:

``` python
weights = [i for i in os.listdir() if "GMVAE" in i and ".h5" in i]
mult_clusters = np.zeros([len(weights), samples])
mean_conf = []

for k, i in tqdm(enumerate(sorted(weights))):
    print(i)
    gmvaep.load_weights(i)
    if montecarlo:
        clusters = np.stack([grouper(data[:samples]) for sample in (tqdm(range(montecarlo)))])
        clusters = clusters.mean(axis=0)
        mean_conf.append(clusters.max(axis=1))
        clusters = np.argmax(clusters, axis=1)
    else:
        clusters = grouper(data[:samples], training=False)
        mean_conf.append(clusters.max(axis=1))
        clusters = np.argmax(clusters, axis=1)
    mult_clusters[k] = clusters
```

%% Cell type:code id: tags:

``` python
clusts.shape
```
%% Cell type:code id: tags:

``` python
import pandas as pd
from itertools import combinations
from sklearn.metrics import adjusted_rand_score
```

%% Cell type:code id: tags:

``` python
mult_clusters
```

%% Cell type:code id: tags:

``` python
thr = 0.95
ari_dist = []
for i, k in enumerate(combinations(range(len(weights)), 2)):
    filt = ((mean_conf[k[0]] > thr) & (mean_conf[k[1]] > thr))
    ari = adjusted_rand_score(mult_clusters[k[0]][filt],
                              mult_clusters[k[1]][filt])
    ari_dist.append(ari)
```

%% Cell type:code id: tags:

``` python
ari_dist
```

%% Cell type:code id: tags:

``` python
random_ari = []
for i in tqdm(range(6)):
    random_ari.append(adjusted_rand_score(np.random.uniform(0,6,50).astype(int),
                                          np.random.uniform(0,6,50).astype(int)))
```

%% Cell type:code id: tags:

``` python
sns.kdeplot(ari_dist, label="ARI gmvaep", shade=True)
sns.kdeplot(random_ari, label="ARI random", shade=True)
plt.xlabel("Normalised Adjusted Rand Index")
plt.ylabel("Density")
plt.legend()
plt.show()
```
%% Cell type:markdown id: tags:

# Cluster differences across conditions

%% Cell type:code id: tags:

``` python
%%time
DLCS1_coords = DLC_social_1_coords.get_coords(center="B_Center", polar=False, length='00:10:00', align='B_Nose')

Treatment_coords = {}
for cond in Treatment_dict.keys():
    Treatment_coords[cond] = DLCS1_coords.filter(Treatment_dict[cond]).preprocess(window_size=13,
                                                                                  window_step=10, filter=None, scale='standard', align='center')
```

%% Cell type:code id: tags:

``` python
%%time
montecarlo = 10
Predictions_per_cond = {}
Confidences_per_cond = {}

for cond in Treatment_dict.keys():
    Predictions_per_cond[cond] = np.stack([grouper(Treatment_coords[cond]
                                                   ) for sample in (tqdm(range(montecarlo)))])
    Confidences_per_cond[cond] = np.mean(Predictions_per_cond[cond], axis=0)
    Predictions_per_cond[cond] = np.argmax(Confidences_per_cond[cond], axis=1)
```
%% Cell type:code id: tags:

``` python
Predictions_per_condition = {k: {cl: [] for cl in range(1, 31)} for k in Treatment_dict.keys()}

for k in Predictions_per_cond.values():
    print(Counter(k))
```

%% Cell type:code id: tags:

``` python
for cond in Treatment_dict.keys():
    start = 0
    for i, j in enumerate(DLCS1_coords.filter(Treatment_dict[cond]).values()):
        update = start + j.shape[0]//10
        counter = Counter(Predictions_per_cond[cond][start:update])
        start += j.shape[0]//10
        for num in counter.keys():
            # counter keys are 0-indexed cluster labels; the dict is 1-indexed
            Predictions_per_condition[cond][num+1].append(counter[num])
```

%% Cell type:code id: tags:

``` python
counts = []
clusters = []
conditions = []
for cond, v in Predictions_per_condition.items():
    for cluster, i in v.items():
        counts += i
        clusters += list(np.repeat(cluster, len(i)))
        conditions += list(np.repeat(cond, len(i)))

Prediction_per_cond_df = pd.DataFrame({'condition': conditions,
                                       'cluster': clusters,
                                       'count': counts})
```

%% Cell type:code id: tags:

``` python
px.box(data_frame=Prediction_per_cond_df, x='cluster', y='count', color='condition')
```
%% Cell type:markdown id: tags:

# Others

%% Cell type:code id: tags:

``` python
for i in range(5):
    print(Counter(labels[str(i)]))
```

%% Cell type:code id: tags:

``` python
adjusted_rand_score(labels[0], labels[3])
```

%% Cell type:code id: tags:

``` python
sns.distplot(ari_dist)
plt.xlabel("Adjusted Rand Index")
plt.ylabel("Count")
plt.show()
```
%% Cell type:code id: tags:

``` python
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
```

%% Cell type:code id: tags:

``` python
from scipy.stats import entropy
```

%% Cell type:code id: tags:

``` python
entropy(np.array([0.5, 0, 0.5, 0]))
```

%% Cell type:code id: tags:

``` python
# note: the first positional argument of tfd.Categorical is `logits`, not probabilities
tfd.Categorical(np.array([0.5, 0.5, 0.5, 0.5])).entropy()
```

%% Cell type:code id: tags:

``` python
pk = np.array([0.5, 0, 0.5, 0])
```

%% Cell type:code id: tags:

``` python
np.log(pk)
```

%% Cell type:code id: tags:

``` python
np.clip(np.log(pk), 0, 1)
```

%% Cell type:code id: tags:

``` python
-np.sum(pk*np.array([-0.69314718, 0, -0.69314718, 0]))
```

%% Cell type:code id: tags:

``` python
import tensorflow.keras.backend as K
# note: without a leading minus this is the negative of the Shannon entropy
entropy = K.sum(tf.multiply(pk, tf.where(~tf.math.is_inf(K.log(pk)), K.log(pk), 0)), axis=0)
entropy
```

%% Cell type:code id: tags:

``` python
sns.distplot(np.max(clusts, axis=1))
sns.distplot(clusts.reshape(clusts.shape[0] * clusts.shape[1]))
plt.axvline(1/10)
plt.show()
```
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
gauss_means = gmvaep.get_layer(name="dense_4").get_weights()[0][:32] gauss_means = gmvaep.get_layer(name="dense_4").get_weights()[0][:32]
gauss_variances = tf.keras.activations.softplus(gmvaep.get_layer(name="dense_4").get_weights()[0][32:]).numpy() gauss_variances = tf.keras.activations.softplus(gmvaep.get_layer(name="dense_4").get_weights()[0][32:]).numpy()
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
gauss_means.shape == gauss_variances.shape gauss_means.shape == gauss_variances.shape
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
k=10 k=10
n=100 n=100
samples = [] samples = []
for i in range(k): for i in range(k):
samples.append(np.random.normal(gauss_means[:,i], gauss_variances[:,i], size=(100,32))) samples.append(np.random.normal(gauss_means[:,i], gauss_variances[:,i], size=(100,32)))
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
from scipy.stats import ttest_ind from scipy.stats import ttest_ind
test_matrix = np.zeros([k,k]) test_matrix = np.zeros([k,k])
for i in range(k): for i in range(k):
for j in range(k): for j in range(k):
test_matrix[i][j] = np.mean(ttest_ind(samples[i], samples[j], equal_var=False)[1]) test_matrix[i][j] = np.mean(ttest_ind(samples[i], samples[j], equal_var=False)[1])
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
threshold = 0.55 threshold = 0.55
np.sum(test_matrix > threshold) np.sum(test_matrix > threshold)
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# Transition matrix # Transition matrix
``` ```
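%% Cell type:markdown id: tags:

The cell above is a stub; a minimal sketch of one way to fill it, assuming `clusters` holds one hard label per window in temporal order (which would require preprocessing with shuffling disabled):

%% Cell type:code id: tags:

``` python
# Hypothetical sketch, not deepof API: row-normalised counts of transitions
# between consecutive window labels.
transitions = np.zeros([k, k])
for prev, nxt in zip(clusters[:-1], clusters[1:]):
    transitions[int(prev), int(nxt)] += 1
transitions /= np.maximum(transitions.sum(axis=1, keepdims=True), 1)

sns.heatmap(transitions, cmap="Blues")
plt.xlabel("cluster at t+1")
plt.ylabel("cluster at t")
plt.show()
```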
%% Cell type:code id: tags:

``` python
Treatment_dict
```

%% Cell type:code id: tags:

``` python
# Anomaly detection - the model was trained on the WT - NS mice alone
gmvaep.load_weights("GMVAE_components=10_loss=ELBO_kl_warmup=20_mmd_warmup=5_20200721-043310_final_weights.h5")
```

%% Cell type:code id: tags:

``` python
WT_NS = table_dict({k: v for k, v in mtest2.items() if k in Treatment_dict['WT+NS']}, typ="coords")
WT_WS = table_dict({k: v for k, v in mtest2.items() if k in Treatment_dict['WT+CSDS']}, typ="coords")
MU_NS = table_dict({k: v for k, v in mtest2.items() if k in Treatment_dict['NatCre+NS']}, typ="coords")
MU_WS = table_dict({k: v for k, v in mtest2.items() if k in Treatment_dict['NatCre+CSDS']}, typ="coords")

preps = [WT_NS.preprocess(window_size=11, window_step=10, filter="gaussian", sigma=55, shift=0, scale="standard", align=True),
         WT_WS.preprocess(window_size=11, window_step=10, filter="gaussian", sigma=55, shift=0, scale="standard", align=True),
         MU_NS.preprocess(window_size=11, window_step=10, filter="gaussian", sigma=55, shift=0, scale="standard", align=True),
         MU_WS.preprocess(window_size=11, window_step=10, filter="gaussian", sigma=55, shift=0, scale="standard", align=True)]
```

%% Cell type:code id: tags:

``` python
preds = [gmvaep.predict(i) for i in preps]
```

%% Cell type:code id: tags:

``` python
```
%% Cell type:code id: tags:

``` python
from sklearn.metrics import mean_absolute_error
reconst_error = {k: mean_absolute_error(preps[i].reshape(preps[i].shape[0]*preps[i].shape[1], 12).T,
                                        preds[i].reshape(preds[i].shape[0]*preds[i].shape[1], 12).T,
                                        multioutput='raw_values') for i, k in enumerate(Treatment_dict.keys())}
reconst_error
```

%% Cell type:code id: tags:

``` python
reconst_df = pd.concat([pd.DataFrame(np.concatenate([np.repeat(k, len(v)).reshape(len(v), 1), v.reshape(len(v), 1)], axis=1)) for k, v in reconst_error.items()])
reconst_df = reconst_df.astype({0: str, 1: float})
```

%% Cell type:code id: tags:

``` python
sns.boxplot(data=reconst_df, x=0, y=1, orient='vertical')
plt.ylabel('Mean Absolute Error')
plt.ylim(0, 0.35)
plt.show()
```

%% Cell type:code id: tags:

``` python
# Check frame rates
```

%% Cell type:code id: tags:

``` python
```
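%% Cell type:markdown id: tags:

The frame-rate check above is also a stub; a minimal sketch using OpenCV's nominal-FPS property, where the video directory is an assumption:

%% Cell type:code id: tags:

``` python
videos_path = os.path.join("..", "..", "Desktop", "deepof-data", "Videos")  # assumed location

for video in sorted(os.listdir(videos_path)):
    cap = cv2.VideoCapture(os.path.join(videos_path, video))
    print(video, cap.get(cv2.CAP_PROP_FPS), "fps")  # the container's nominal frame rate
    cap.release()
```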
...
@@ -80,7 +80,7 @@ def test_get_callbacks(
         elements=st.floats(min_value=0.0, max_value=1,),
     ),
     batch_size=st.integers(min_value=128, max_value=512),
-    hypermodel=st.one_of(st.just("S2SAE"), st.just("S2SGMVAE")),
+    hypermodel=st.just("S2SGMVAE"),
     k=st.integers(min_value=1, max_value=10),
     kl_wu=st.integers(min_value=0, max_value=10),
     loss=st.one_of(st.just("ELBO"), st.just("MMD")),
...