diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..5c98b428844d9f7d529e2b6fb918d15bf072f3df
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,2 @@
+# Default ignored files
+/workspace.xml
\ No newline at end of file
diff --git a/.idea/.name b/.idea/.name
new file mode 100644
index 0000000000000000000000000000000000000000..d8773cf97151415186f7224b2ab4d3d5d2453759
--- /dev/null
+++ b/.idea/.name
@@ -0,0 +1 @@
+DeepOF
\ No newline at end of file
diff --git a/.idea/DeepOF.iml b/.idea/DeepOF.iml
new file mode 100644
index 0000000000000000000000000000000000000000..1ed70df98cb5201c967e8fa9db4c5e132eac1564
--- /dev/null
+++ b/.idea/DeepOF.iml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$" />
+    <orderEntry type="jdk" jdkName="Python 3.6 (Machine_Learning)" jdkType="Python SDK" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+</module>
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000000000000000000000000000000000000..a30251052c39ac3bbc31b7405b488dd6d2a52cee
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,20 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.6 (Machine_Learning)" project-jdk-type="Python SDK" />
+  <component name="PyCharmProfessionalAdvertiser">
+    <option name="shown" value="true" />
+  </component>
+  <component name="RMarkdownSettings">
+    <option name="renderProfiles">
+      <map>
+        <entry key="$PROJECT_DIR$/README.rmd">
+          <value>
+            <RMarkdownRenderProfile>
+              <option name="knitRootDirectory" value="$PROJECT_DIR$" />
+            </RMarkdownRenderProfile>
+          </value>
+        </entry>
+      </map>
+    </option>
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000000000000000000000000000000000000..1aabd7c71488285c9e3483bd740e0d5a92872aaf
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/DeepOF.iml" filepath="$PROJECT_DIR$/.idea/DeepOF.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/rGraphicsSettings.xml b/.idea/rGraphicsSettings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..e7a7db1f406d7ceda5c3ebb91977c31d958db3af
--- /dev/null
+++ b/.idea/rGraphicsSettings.xml
@@ -0,0 +1,9 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="RGraphicsSettings">
+    <option name="height" value="1050" />
+    <option name="resolution" value="75" />
+    <option name="version" value="1" />
+    <option name="width" value="1680" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/rSettings.xml b/.idea/rSettings.xml
new file mode 100644
index 0000000000000000000000000000000000000000..6d7112b176c6187ac27c55e42282847020e8d81e
--- /dev/null
+++ b/.idea/rSettings.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="RSettings">
+    <option name="interpreterPath" value="/usr/local/bin/R" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/rpackages.xml b/.idea/rpackages.xml
new file mode 100644
index 0000000000000000000000000000000000000000..1709b2daf5012fca75a99a0416fe2c3f8dcee2c1
--- /dev/null
+++ b/.idea/rpackages.xml
@@ -0,0 +1,10 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="RPackageService">
+    <option name="enabledRepositoryUrls">
+      <list>
+        <option value="@CRAN@" />
+      </list>
+    </option>
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000000000000000000000000000000000000..94a25f7f4cb416c083d265558da75d457237d671
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/.ipynb_checkpoints/main-checkpoint.ipynb b/.ipynb_checkpoints/main-checkpoint.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..6d654ba312f53900a0bb1624860713bfd0af543a
--- /dev/null
+++ b/.ipynb_checkpoints/main-checkpoint.ipynb
@@ -0,0 +1,738 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "\n",
+    "import warnings\n",
+    "warnings.filterwarnings(\"ignore\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#from source.utils import *\n",
+    "from source.preprocess import *\n",
+    "import pickle\n",
+    "import matplotlib.pyplot as plt\n",
+    "import pandas as pd\n",
+    "from collections import defaultdict\n",
+    "from tqdm import tqdm_notebook as tqdm"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Set up and design the project"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "with open('../../Desktop/DLC_social_1/DLC_social_1_exp_conditions.pickle', 'rb') as handle:\n",
+    "    Treatment_dict = pickle.load(handle)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Which angles to compute?\n",
+    "bp_dict = {'B_Nose':['B_Left_ear','B_Right_ear'],\n",
+    "          'B_Left_ear':['B_Nose','B_Right_ear','B_Center','B_Left_flank'],\n",
+    "          'B_Right_ear':['B_Nose','B_Left_ear','B_Center','B_Right_flank'],\n",
+    "          'B_Center':['B_Left_ear','B_Right_ear','B_Left_flank','B_Right_flank','B_Tail_base'],\n",
+    "          'B_Left_flank':['B_Left_ear','B_Center','B_Tail_base'],\n",
+    "          'B_Right_flank':['B_Right_ear','B_Center','B_Tail_base'],\n",
+    "          'B_Tail_base':['B_Center','B_Left_flank','B_Right_flank']}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 2.71 s, sys: 837 ms, total: 3.54 s\n",
+      "Wall time: 1.13 s\n"
+     ]
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "DLC_social_1 = project(path='../../Desktop/DLC_social_1/',#Path where to find the required files\n",
+    "                   smooth_alpha=0.85,                    #Alpha value for exponentially weighted smoothing\n",
+    "                   distances=['B_Center','B_Nose','B_Left_ear','B_Right_ear','B_Left_flank',\n",
+    "                              'B_Right_flank','B_Tail_base'],\n",
+    "                   ego=False,\n",
+    "                   angles=True,\n",
+    "                   connectivity=bp_dict,\n",
+    "                   arena='circular',                  #Type of arena used in the experiments\n",
+    "                   arena_dims=[380],                  #Dimensions of the arena. Just one if it's circular\n",
+    "                   video_format='.mp4',\n",
+    "                   table_format='.h5',\n",
+    "                   exp_conditions=Treatment_dict)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Run project"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loading trajectories...\n",
+      "Smoothing trajectories...\n",
+      "Computing distances...\n",
+      "Computing angles...\n",
+      "Done!\n",
+      "Coordinates of 47 videos across 4 conditions\n",
+      "CPU times: user 9.19 s, sys: 562 ms, total: 9.75 s\n",
+      "Wall time: 10 s\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "source.preprocess.coordinates"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "DLC_social_1_coords = DLC_social_1.run(verbose=True)\n",
+    "print(DLC_social_1_coords)\n",
+    "type(DLC_social_1_coords)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Generate coords"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 832 ms, sys: 72.8 ms, total: 905 ms\n",
+      "Wall time: 857 ms\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'coords'"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "ptest = DLC_social_1_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')\n",
+    "ptest._type"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 525 ms, sys: 365 ms, total: 890 ms\n",
+      "Wall time: 889 ms\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'dists'"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "dtest = DLC_social_1_coords.get_distances(speed=0, length='00:10:00')\n",
+    "dtest._type"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 140 ms, sys: 80.3 ms, total: 221 ms\n",
+      "Wall time: 220 ms\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'angles'"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "%%time\n",
+    "atest = DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00')\n",
+    "atest._type"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Visualization playground"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ptest.plot_heatmaps(['B_Center', 'W_Center'], i=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Plot animation of trajectory over time with different smoothings\n",
+    "plt.plot(ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['x'],\n",
+    "         ptest['Day2Test13DLC']['B_Center'].iloc[:5000]['y'], label='alpha=0.85')\n",
+    "\n",
+    "plt.xlabel('x')\n",
+    "plt.ylabel('y')\n",
+    "plt.title('Mouse Center Trajectory using different exponential smoothings')\n",
+    "plt.legend()\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Dimensionality reduction playground"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pca = ptest.pca(4, 1000)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.scatter(*pca[0].T)\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Preprocessing playground"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "mtest = merge_tables(DLC_social_1_coords.get_coords(center=True, polar=True, length='00:10:00'))#,\n",
+    "#                      DLC_social_1_coords.get_distances(speed=0, length='00:10:00'),\n",
+    "#                      DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pptest = mtest.preprocess(window_size=51, filter='gaussian', sigma=10, shift=20)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pttest = mtest.preprocess(window_size=51, filter=None)\n",
+    "pttest.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.plot(pttest[2,:,2], label='normal')\n",
+    "plt.plot(pptest[2,:,2], label='gaussian')\n",
+    "plt.legend()\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Trained models playground"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Seq 2 seq Variational Auto Encoder"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pttest = pttest[:1000]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "CONV_filters = 64\n",
+    "LSTM_units_1 = 128\n",
+    "LSTM_units_2 = 64\n",
+    "DENSE_1 = 64\n",
+    "DENSE_2 = 32\n",
+    "ENCODING = 20\n",
+    "DROPOUT_RATE = 0.2\n",
+    "\n",
+    "original_dim = pttest.shape[1:]\n",
+    "batch_size = 256"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from source.hypermodels import *\n",
+    "import tensorflow as tf\n",
+    "from tensorflow.keras import Input, Model\n",
+    "from tensorflow.keras.layers import Dense, Lambda, Bidirectional, LSTM\n",
+    "from tensorflow.keras import backend as K\n",
+    "K.clear_session()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class KLDivergenceLayer(Layer):\n",
+    "\n",
+    "    \"\"\" Identity transform layer that adds KL divergence\n",
+    "    to the final model loss.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(self, *args, **kwargs):\n",
+    "        self.is_placeholder = True\n",
+    "        super(KLDivergenceLayer, self).__init__(*args, **kwargs)\n",
+    "\n",
+    "    def call(self, inputs):\n",
+    "\n",
+    "        mu, log_var = inputs\n",
+    "\n",
+    "        kl_batch = - .5 * K.sum(1 + log_var -\n",
+    "                                K.square(mu) -\n",
+    "                                K.exp(log_var), axis=-1)\n",
+    "\n",
+    "        self.add_loss(K.mean(kl_batch), inputs=inputs)\n",
+    "\n",
+    "        return inputs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class MMDiscrepancyLayer(Layer):\n",
+    "\n",
+    "    \"\"\" Identity transform layer that adds MM discrepancy\n",
+    "    to the final model loss.\n",
+    "    \"\"\"\n",
+    "\n",
+    "    def __init__(self, *args, **kwargs):\n",
+    "        self.is_placeholder = True\n",
+    "        super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)\n",
+    "\n",
+    "    def call(self, z):\n",
+    "        \n",
+    "        true_samples = K.random_normal(K.shape(z), mean=0., stddev=2./K.cast_to_floatx(K.shape(z)[1]))        \n",
+    "        mmd_batch = compute_mmd(z, true_samples)\n",
+    "        \n",
+    "        self.add_loss(K.mean(mmd_batch), inputs=z)\n",
+    "\n",
+    "        return z"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Encoder Layers\n",
+    "Model_E0 = tf.keras.layers.Conv1D(\n",
+    "    filters=CONV_filters,\n",
+    "    kernel_size=5,\n",
+    "    strides=1,\n",
+    "    padding=\"causal\",\n",
+    "    activation=\"relu\",\n",
+    ")\n",
+    "Model_E1 = Bidirectional(\n",
+    "    LSTM(\n",
+    "        LSTM_units_1,\n",
+    "        activation=\"tanh\",\n",
+    "        return_sequences=True,\n",
+    "        kernel_constraint=UnitNorm(axis=0),\n",
+    "    )\n",
+    ")\n",
+    "Model_E2 = Bidirectional(\n",
+    "    LSTM(\n",
+    "        LSTM_units_2,\n",
+    "        activation=\"tanh\",\n",
+    "        return_sequences=False,\n",
+    "        kernel_constraint=UnitNorm(axis=0),\n",
+    "    )\n",
+    ")\n",
+    "Model_E3 = Dense(DENSE_1, activation=\"relu\", kernel_constraint=UnitNorm(axis=0))\n",
+    "Model_E4 = Dense(DENSE_2, activation=\"relu\", kernel_constraint=UnitNorm(axis=0))\n",
+    "Model_E5 = Dense(\n",
+    "            ENCODING,\n",
+    "            activation=\"relu\",\n",
+    "            kernel_constraint=UnitNorm(axis=1),\n",
+    "            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),\n",
+    "        )\n",
+    "\n",
+    "# Decoder layers\n",
+    "Model_D4 = Bidirectional(\n",
+    "    LSTM(\n",
+    "        LSTM_units_1,\n",
+    "        activation=\"tanh\",\n",
+    "        return_sequences=True,\n",
+    "        kernel_constraint=UnitNorm(axis=1),\n",
+    "    )\n",
+    ")\n",
+    "Model_D5 = Bidirectional(\n",
+    "    LSTM(\n",
+    "        LSTM_units_1,\n",
+    "        activation=\"sigmoid\",\n",
+    "        return_sequences=True,\n",
+    "        kernel_constraint=UnitNorm(axis=1),\n",
+    "    )\n",
+    ")\n",
+    "\n",
+    "# Define and instanciate encoder\n",
+    "x = Input(shape=original_dim)\n",
+    "encoder = Model_E0(x)\n",
+    "encoder = Model_E1(encoder)\n",
+    "encoder = Model_E2(encoder)\n",
+    "encoder = Model_E3(encoder)\n",
+    "encoder = Dropout(DROPOUT_RATE)(encoder)\n",
+    "encoder = Model_E4(encoder)\n",
+    "encoder = Model_E5(encoder)\n",
+    "z_mean = Dense(ENCODING)(encoder)\n",
+    "z_log_sigma = Dense(ENCODING)(encoder)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def sampling(args, epsilon_std=1.):\n",
+    "    z_mean, z_log_sigma = args\n",
+    "    epsilon = K.random_normal(shape=K.shape(z_mean),\n",
+    "                              mean=0., stddev=epsilon_std)\n",
+    "    return z_mean + K.exp(z_log_sigma) * epsilon\n",
+    "\n",
+    "# note that \"output_shape\" isn't necessary with the TensorFlow backend\n",
+    "# so you could write `Lambda(sampling)([z_mean, z_log_sigma])`\n",
+    "\n",
+    "z_mean, z_log_sigma = KLDivergenceLayer()([z_mean, z_log_sigma])\n",
+    "z = Lambda(sampling)([z_mean, z_log_sigma])\n",
+    "z = MMDiscrepancyLayer()(z)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Define and instanciate decoder\n",
+    "decoder = DenseTranspose(Model_E5, activation=\"relu\", output_dim=ENCODING)(z)\n",
+    "decoder = DenseTranspose(Model_E4, activation=\"relu\", output_dim=DENSE_2)(decoder)\n",
+    "decoder = DenseTranspose(Model_E3, activation=\"relu\", output_dim=DENSE_1)(decoder)\n",
+    "decoder = RepeatVector(pttest.shape[1])(decoder)\n",
+    "decoder = Model_D4(decoder)\n",
+    "decoder = Model_D5(decoder)\n",
+    "x_decoded_mean = TimeDistributed(Dense(original_dim[1]))(decoder)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# end-to-end autoencoder\n",
+    "vae = Model(x, x_decoded_mean)\n",
+    "\n",
+    "# encoder, from inputs to latent space\n",
+    "encoder = Model(x, z_mean)\n",
+    "\n",
+    "# generator, from latent space to reconstructed inputs\n",
+    "decoder_input = Input(shape=(ENCODING,))\n",
+    "decoder = DenseTranspose(Model_E5, activation=\"relu\", output_dim=ENCODING)(decoder_input)\n",
+    "decoder = DenseTranspose(Model_E4, activation=\"relu\", output_dim=DENSE_2)(decoder)\n",
+    "decoder = DenseTranspose(Model_E3, activation=\"relu\", output_dim=DENSE_1)(decoder)\n",
+    "decoder = RepeatVector(pttest.shape[1])(decoder)\n",
+    "decoder = Model_D4(decoder)\n",
+    "decoder = Model_D5(decoder)\n",
+    "x_decoded_mean = TimeDistributed(Dense(original_dim[1]))(decoder)\n",
+    "generator = Model(decoder_input, x_decoded_mean)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vae.summary()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "tf.keras.utils.plot_model(vae, show_shapes=True, show_layer_names=False)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def compute_kernel(x, y):\n",
+    "    x_size = K.shape(x)[0]\n",
+    "    y_size = K.shape(y)[0]\n",
+    "    dim = K.shape(x)[1]\n",
+    "    tiled_x = K.tile(K.reshape(x, K.stack([x_size, 1, dim])), K.stack([1, y_size, 1]))\n",
+    "    tiled_y = K.tile(K.reshape(y, K.stack([1, y_size, dim])), K.stack([x_size, 1, 1]))\n",
+    "    return K.exp(-tf.reduce_mean(K.square(tiled_x - tiled_y), axis=2) / K.cast(dim, tf.float32))\n",
+    "\n",
+    "def compute_mmd(x, y):\n",
+    "    x_kernel = compute_kernel(x, x)\n",
+    "    y_kernel = compute_kernel(y, y)\n",
+    "    xy_kernel = compute_kernel(x, y)\n",
+    "    return tf.reduce_mean(x_kernel) + tf.reduce_mean(y_kernel) - 2 * tf.reduce_mean(xy_kernel)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from tensorflow.keras.losses import Huber\n",
+    "\n",
+    "def huber_loss(x, x_decoded_mean):\n",
+    "    huber_loss = Huber(reduction=\"sum\", delta=100.0)\n",
+    "    return original_dim * huber_loss(x, x_decoded_mean)\n",
+    "\n",
+    "vae.compile(optimizer='adam', loss=huber_loss, experimental_run_tf_function=False, metrics=['mae'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "#tf.config.experimental_run_functions_eagerly(False)\n",
+    "ptrain = pttest[np.random.choice(pttest.shape[0], 1000, replace=False), :, :]\n",
+    "history = vae.fit(ptrain, ptrain, epochs=50, batch_size=batch_size, verbose=1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "#plt.plot(history.history['mae'], label='Huber + MMD mae')\n",
+    "plt.plot(history.history['mae'], label='Huber + KL mae')\n",
+    "\n",
+    "plt.legend()\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Huber loss + MMD/ELBO in training data\n",
+    "plt.plot(pttest[:2000,0,0], label='data')\n",
+    "plt.plot(vae.predict(pttest[:2000])[:,0,0], label='MMD reconstruction')\n",
+    "\n",
+    "plt.legend()\n",
+    "plt.show()"
+   ]
+  },
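+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal sketch of using the generator defined above: decoding draws from the\n",
+    "# latent prior yields novel sequences (only meaningful once the VAE is trained).\n",
+    "# Shapes follow the ENCODING and pttest definitions in this notebook.\n",
+    "latent_samples = np.random.normal(size=(5, ENCODING))\n",
+    "generated = generator.predict(latent_samples)\n",
+    "print(generated.shape)  # (5, window_size, n_features)\n",
+    "\n",
+    "plt.plot(generated[0, :, 0], label='generated feature 0')\n",
+    "plt.legend()\n",
+    "plt.show()"
+   ]
+  },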
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/source/model_utils.py b/source/model_utils.py
index 3ce8e15db7eb2756af060716383d730adeb91d63..03bb1bd74f9749e00430975fea1091fc2cb7c085 100644
--- a/source/model_utils.py
+++ b/source/model_utils.py
@@ -143,10 +143,9 @@ class MMDiscrepancyLayer(Layer):
         super(MMDiscrepancyLayer, self).__init__(*args, **kwargs)
 
     def call(self, z, **kwargs):
-        true_samples = K.random_normal(
-            K.shape(z), mean=0.0, stddev=2.0 / K.cast_to_floatx(K.shape(z)[1])
-        )
-        mmd_batch = compute_mmd(z, true_samples)
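+        # Draw reference samples from a standard normal prior, matching the sampling layer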
+        true_samples = K.random_normal(K.shape(z))
+        mmd_batch = compute_mmd(true_samples, z)
 
         self.add_loss(K.mean(mmd_batch), inputs=z)
 
diff --git a/source/models.py b/source/models.py
index dcb21880d5a770097ef5a6a89b8418dc66e0d21f..40115979f8f44f2b40d2b9be13e523be8e624585 100644
--- a/source/models.py
+++ b/source/models.py
@@ -21,7 +21,7 @@ class SEQ_2_SEQ_AE:
         DENSE_2,
         DROPOUT_RATE,
         ENCODING,
-        learn_rate
+        learn_rate,
     ):
         self.input_shape = input_shape
         self.CONV_filters = CONV_filters
@@ -59,8 +59,12 @@ class SEQ_2_SEQ_AE:
                 kernel_constraint=UnitNorm(axis=0),
             )
         )
-        Model_E3 = Dense(self.DENSE_1, activation="relu", kernel_constraint=UnitNorm(axis=0))
-        Model_E4 = Dense(self.DENSE_2, activation="relu", kernel_constraint=UnitNorm(axis=0))
+        Model_E3 = Dense(
+            self.DENSE_1, activation="relu", kernel_constraint=UnitNorm(axis=0)
+        )
+        Model_E4 = Dense(
+            self.DENSE_2, activation="relu", kernel_constraint=UnitNorm(axis=0)
+        )
         Model_E5 = Dense(
             self.ENCODING,
             activation="relu",
@@ -114,10 +118,7 @@ class SEQ_2_SEQ_AE:
 
         model.compile(
             loss=Huber(reduction="sum", delta=100.0),
-            optimizer=Adam(
-                lr=self.learn_rate,
-                clipvalue=0.5,
-            ),
+            optimizer=Adam(lr=self.learn_rate, clipvalue=0.5,),
             metrics=["mae"],
         )
 
@@ -125,12 +126,157 @@
 
 
 class SEQ_2_SEQ_VAE:
-    pass
+    def __init__(
+        self,
+        input_shape,
+        CONV_filters,
+        LSTM_units_1,
+        LSTM_units_2,
+        DENSE_2,
+        DROPOUT_RATE,
+        ENCODING,
+        learn_rate,
+        loss="ELBO+MMD",
+    ):
+        self.input_shape = input_shape
+        self.CONV_filters = CONV_filters
+        self.LSTM_units_1 = LSTM_units_1
+        self.LSTM_units_2 = LSTM_units_2
+        self.DENSE_1 = LSTM_units_2
+        self.DENSE_2 = DENSE_2
+        self.DROPOUT_RATE = DROPOUT_RATE
+        self.ENCODING = ENCODING
+        self.learn_rate = learn_rate
+        self.loss = loss
+
+    def build(self):
+        # Encoder Layers
+        Model_E0 = tf.keras.layers.Conv1D(
+            filters=self.CONV_filters,
+            kernel_size=5,
+            strides=1,
+            padding="causal",
+            activation="relu",
+        )
+        Model_E1 = Bidirectional(
+            LSTM(
+                self.LSTM_units_1,
+                activation="tanh",
+                return_sequences=True,
+                kernel_constraint=UnitNorm(axis=0),
+            )
+        )
+        Model_E2 = Bidirectional(
+            LSTM(
+                self.LSTM_units_2,
+                activation="tanh",
+                return_sequences=False,
+                kernel_constraint=UnitNorm(axis=0),
+            )
+        )
+        Model_E3 = Dense(
+            self.DENSE_1, activation="relu", kernel_constraint=UnitNorm(axis=0)
+        )
+        Model_E4 = Dense(
+            self.DENSE_2, activation="relu", kernel_constraint=UnitNorm(axis=0)
+        )
+        Model_E5 = Dense(
+            self.ENCODING,
+            activation="relu",
+            kernel_constraint=UnitNorm(axis=1),
+            activity_regularizer=UncorrelatedFeaturesConstraint(3, weightage=1.0),
+        )
+
+        # Decoder layers
+        Model_D4 = Bidirectional(
+            LSTM(
+                self.LSTM_units_1,
+                activation="tanh",
+                return_sequences=True,
+                kernel_constraint=UnitNorm(axis=1),
+            )
+        )
+        Model_D5 = Bidirectional(
+            LSTM(
+                self.LSTM_units_1,
+                activation="sigmoid",
+                return_sequences=True,
+                kernel_constraint=UnitNorm(axis=1),
+            )
+        )
+
+        # Define and instantiate encoder
+        x = Input(shape=self.input_shape[1:])
+        encoder = Model_E0(x)
+        encoder = Model_E1(encoder)
+        encoder = Model_E2(encoder)
+        encoder = Model_E3(encoder)
+        encoder = Dropout(self.DROPOUT_RATE)(encoder)
+        encoder = Model_E4(encoder)
+        encoder = Model_E5(encoder)
+
+        z_mean = Dense(self.ENCODING)(encoder)
+        z_log_sigma = Dense(self.ENCODING)(encoder)
+
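+        # Latent regularization: "ELBO" adds a KL term, "MMD" a maximum mean discrepancy term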
+        if "ELBO" in self.loss:
+            z_mean, z_log_sigma = KLDivergenceLayer()([z_mean, z_log_sigma])
+
+        z = Lambda(sampling)([z_mean, z_log_sigma])
+
+        if "MMD" in self.loss:
+            z = MMDiscrepancyLayer()(z)
+
+        # Define and instantiate decoder
+        decoder = DenseTranspose(Model_E5, activation="relu", output_dim=self.ENCODING)(z)
+        decoder = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2)(
+            decoder
+        )
+        decoder = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1)(
+            decoder
+        )
+        decoder = RepeatVector(self.input_shape[1])(decoder)
+        decoder = Model_D4(decoder)
+        decoder = Model_D5(decoder)
+        x_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(decoder)
+
+        # end-to-end autoencoder
+        vae = Model(x, x_decoded_mean)
+
+        # encoder, from inputs to latent space
+        encoder = Model(x, z_mean)
+
+        # generator, from latent space to reconstructed sequences; it reuses the
+        # decoder layer stack defined above on a standalone latent input
+        g_input = Input(shape=(self.ENCODING,))
+        g = DenseTranspose(Model_E5, activation="relu", output_dim=self.ENCODING)(g_input)
+        g = DenseTranspose(Model_E4, activation="relu", output_dim=self.DENSE_2)(g)
+        g = DenseTranspose(Model_E3, activation="relu", output_dim=self.DENSE_1)(g)
+        g = RepeatVector(self.input_shape[1])(g)
+        g = Model_D4(g)
+        g = Model_D5(g)
+        g_decoded_mean = TimeDistributed(Dense(self.input_shape[2]))(g)
+        generator = Model(g_input, g_decoded_mean)
+
+        def huber_loss(x, x_decoded_mean):
+            # Scale the reconstruction term by the feature count so it is not
+            # dwarfed by the KL/MMD terms added through the custom layers
+            huber = Huber(reduction="sum", delta=100.0)
+            return self.input_shape[2] * huber(x, x_decoded_mean)
+
+        vae.compile(
+            loss=huber_loss,
+            optimizer=Adam(lr=self.learn_rate, clipvalue=0.5,),
+            metrics=["mae"],
+            experimental_run_tf_function=False,
+        )
+
+        return encoder, generator, vae
 
 
-class SEQ_2_SEQ_MVAE():
+class SEQ_2_SEQ_MVAE:
     pass
 
 
-class SEQ_2_SEQ_MMVAE():
+class SEQ_2_SEQ_MMVAE:
     pass