Commit cf0b8980 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added _scaler method to table_dict, to retrieve the scaler used to preprocess data

parent fca79d55
......@@ -1158,6 +1158,7 @@ class table_dict(dict):
if scale == "standard":
self._scaler = StandardScaler()
elif scale == "minmax":
self._scaler = MinMaxScaler()
else:
......@@ -1165,18 +1166,20 @@ class table_dict(dict):
"Invalid scaler. Select one of standard, minmax or None"
) # pragma: no cover
X_train = self._scaler.fit_transform(
X_train.reshape(-1, X_train.shape[-1])
).reshape(X_train.shape)
X_train_flat = X_train.reshape(-1, X_train.shape[-1])
self._scaler.fit(X_train_flat)
X_train = self._scaler.transform(X_train_flat).reshape(X_train.shape)
if scale == "standard":
assert np.allclose(np.nan_to_num(np.mean(X_train), nan=0), 0)
assert np.allclose(np.nan_to_num(np.std(X_train), nan=1), 1)
assert np.all(np.nan_to_num(np.mean(X_train), nan=0) < 0.1)
assert np.all(np.nan_to_num(np.std(X_train), nan=1) > 0.9)
if test_videos:
X_test = self._scaler.transform(X_test.reshape(-1, X_test.shape[-1])).reshape(
X_test.shape
)
X_test = self._scaler.transform(
X_test.reshape(-1, X_test.shape[-1])
).reshape(X_test.shape)
if verbose:
print("Done!")
......
......@@ -552,7 +552,7 @@ class SEQ_2_SEQ_GMVAE:
encoder = BatchNormalization()(encoder)
encoder = Dropout(self.DROPOUT_RATE)(encoder)
encoder = Sequential(Model_E4)(encoder)
# encoder = BatchNormalization()(encoder)
encoder = BatchNormalization()(encoder)
# encoding_shuffle = deepof.model_utils.MCDropout(self.DROPOUT_RATE)(encoder)
z_cat = Dense(
......@@ -665,7 +665,7 @@ class SEQ_2_SEQ_GMVAE:
generator = Model_D4(generator)
generator = Model_B3(generator)
generator = Model_D5(generator)
# generator = Model_B4(generator)
generator = Model_B4(generator)
generator = Dense(tfpl.IndependentNormal.params_size(input_shape[2:]))(
generator
)
......
%% Cell type:code id: tags:
``` python
import os
os.chdir(os.path.dirname("../"))
```
%% Cell type:code id: tags:
``` python
import warnings
warnings.filterwarnings("ignore")
```
%% Cell type:markdown id: tags:
......@@ -156,10 +158,11 @@
%% Cell type:code id: tags:
``` python
random_exp = np.random.choice(list(position_coords.keys()), 1)[0]
@interact(time_slider=(0.0, 15000, 25), length_slider=(10, 100, 5))
def plot_mice_across_time(time_slider, length_slider):
plt.figure(figsize=(10, 10))
......@@ -200,18 +203,20 @@
%% Cell type:code id: tags:
``` python
# Auxiliary animation functions
def plot_mouse_graph(instant_x, instant_y, ax, edges):
"""Generates a graph plot of the mouse"""
plots = []
for edge in edges:
(temp_plot,) = ax.plot(
[float(instant_x[edge[0]]), float(instant_x[edge[1]])],
[float(instant_y[edge[0]]), float(instant_y[edge[1]])],
color="#006699",
linewidth=2.0,
)
plots.append(temp_plot)
return plots
......@@ -227,24 +232,29 @@
%% Cell type:code id: tags:
``` python
random_exp = np.random.choice(list(position_coords.keys()), 1)[0]
print(random_exp)
@interact(time_slider=(0.0, 15000, 25), length_slider=(10, 100, 5))
def animate_mice_across_time(time_slider, length_slider):
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
edges = deepof.utils.connect_mouse_topview()
for bpart in exclude_bodyparts:
edges.remove_node(bpart)
if bpart:
edges.remove_node(bpart)
for limb in ["Left_fhip", "Right_fhip", "Left_bhip", "Right_bhip"]:
edges.remove_edge("Center", limb)
if ("Tail_base", limb) in list(edges.edges()):
edges.remove_edge("Tail_base", limb)
edges = edges.edges()
data = position_coords[random_exp].loc[
time_slider : time_slider + length_slider - 1, :
]
......@@ -272,11 +282,11 @@
animation = FuncAnimation(
fig, func=animation_frame, frames=length_slider, interval=100,
)
ax.set_title("Positions across time for centered data")
ax.set_ylim(-100, 60)
ax.set_ylim(-90, 60)
ax.set_xlim(-60, 60)
ax.set_xlabel("x")
ax.set_ylabel("y")
video = animation.to_html5_video()
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment