From 766c9002f39c450d82f1910d1b7a4262ccef4f55 Mon Sep 17 00:00:00 2001
From: lucas_miranda <lucasmiranda42@gmail.com>
Date: Fri, 29 May 2020 16:53:43 +0200
Subject: [PATCH] Added MinMaxScaler option to preprocessing; switched model losses from Huber to binary cross-entropy

---
 main.ipynb           | 162 +++++++++++++++++++++++++++++++++++--------
 source/models.py     |   8 +--
 source/preprocess.py |   8 ++-
 3 files changed, 142 insertions(+), 36 deletions(-)

diff --git a/main.ipynb b/main.ipynb
index a827e06a..79673605 100644
--- a/main.ipynb
+++ b/main.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -15,7 +15,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -30,7 +30,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 3,
    "metadata": {
     "tags": [
      "parameters"
@@ -50,7 +50,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -60,7 +60,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -76,9 +76,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 2.72 s, sys: 814 ms, total: 3.53 s\n",
+      "Wall time: 1.3 s\n"
+     ]
+    }
+   ],
    "source": [
     "%%time\n",
     "DLC_social_1 = project(path=path,#Path where to find the required files\n",
@@ -104,9 +113,34 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Loading trajectories...\n",
+      "Smoothing trajectories...\n",
+      "Computing distances...\n",
+      "Computing angles...\n",
+      "Done!\n",
+      "Coordinates of 47 videos across 4 conditions\n",
+      "CPU times: user 9.14 s, sys: 558 ms, total: 9.69 s\n",
+      "Wall time: 13.1 s\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "source.preprocess.coordinates"
+      ]
+     },
+     "execution_count": 7,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "%%time\n",
     "DLC_social_1_coords = DLC_social_1.run(verbose=True)\n",
@@ -123,11 +157,30 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 8,
    "metadata": {
     "scrolled": true
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 812 ms, sys: 70.2 ms, total: 882 ms\n",
+      "Wall time: 853 ms\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'coords'"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "%%time\n",
     "ptest = DLC_social_1_coords.get_coords(center=True, polar=False, speed=0, length='00:10:00')\n",
@@ -136,9 +189,28 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 553 ms, sys: 382 ms, total: 935 ms\n",
+      "Wall time: 935 ms\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'dists'"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "%%time\n",
     "dtest = DLC_social_1_coords.get_distances(speed=0, length='00:10:00')\n",
@@ -147,9 +219,28 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "CPU times: user 139 ms, sys: 84.2 ms, total: 223 ms\n",
+      "Wall time: 223 ms\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'angles'"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "%%time\n",
     "atest = DLC_social_1_coords.get_angles(degrees=True, speed=0, length='00:10:00')\n",
@@ -165,7 +256,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -174,7 +265,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 12,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -198,7 +289,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 13,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -207,7 +298,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -224,7 +315,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 15,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -235,7 +326,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 16,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -244,11 +335,22 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "pttest = mtest.preprocess(window_size=11, window_step=6, filter=None)\n",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(117507, 11, 28)"
+      ]
+     },
+     "execution_count": 20,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "pttest = mtest.preprocess(window_size=11, window_step=6, filter=None, standard_scaler=False)\n",
     "pttest.shape"
    ]
   },
diff --git a/source/models.py b/source/models.py
index 65d34904..cd4c68e1 100644
--- a/source/models.py
+++ b/source/models.py
@@ -6,7 +6,7 @@ from tensorflow.keras.initializers import he_uniform, Orthogonal
 from tensorflow.keras.layers import BatchNormalization, Bidirectional, Dense
 from tensorflow.keras.layers import Dropout, Lambda, LSTM
 from tensorflow.keras.layers import RepeatVector, TimeDistributed
-from tensorflow.keras.losses import Huber
+from tensorflow.keras.losses import BinaryCrossentropy, Huber
 from tensorflow.keras.optimizers import Adam
 from source.model_utils import *
 import tensorflow as tf
@@ -137,7 +137,7 @@ class SEQ_2_SEQ_AE:
         model = Sequential([encoder, decoder], name="SEQ_2_SEQ_AE")
 
         model.compile(
-            loss=Huber(reduction="sum", delta=100.0),
+            loss="binary_crossentropy",#Huber(reduction="sum", delta=100.0),
             optimizer=Adam(lr=self.learn_rate, clipvalue=0.5,),
             metrics=["mae"],
         )
@@ -309,7 +309,7 @@ class SEQ_2_SEQ_VAE:
             return self.input_shape[1:] * huber(x_, x_decoded_mean_)
 
         vae.compile(
-            loss=huber_loss,
+            loss="binary_crossentropy",#huber_loss,
             optimizer=Adam(lr=self.learn_rate,),
             metrics=["mae"],
             experimental_run_tf_function=False,
@@ -518,7 +518,7 @@ class SEQ_2_SEQ_VAEP:
             return self.input_shape[1:] * huber(x_, x_decoded_mean_)
 
         vaep.compile(
-            loss=huber_loss,
+            loss="binary_crossentropy",
             optimizer=Adam(lr=self.learn_rate,),
             metrics=["mae"],
             experimental_run_tf_function=False,
diff --git a/source/preprocess.py b/source/preprocess.py
index 0dffa671..881abd3f 100644
--- a/source/preprocess.py
+++ b/source/preprocess.py
@@ -493,8 +493,12 @@ class table_dict(dict):
                 X_train.reshape(-1, X_train.shape[-1])
             ).reshape(X_train.shape)
 
-            assert np.allclose(np.mean(X_train), 0)
-            assert np.allclose(np.std(X_train), 1)
+            if standard_scaler:
+                assert np.allclose(np.mean(X_train), 0)
+                assert np.allclose(np.std(X_train), 1)
+            else:
+                assert np.all(X_train >= 0)
+                assert np.all(X_train <= 1)
 
             if test_proportion:
                 X_test = scaler.transform(X_test.reshape(-1, X_test.shape[-1])).reshape(
-- 
GitLab