Commit 3b3bea35 authored by lucas_miranda's avatar lucas_miranda
Browse files

Added extra branch to main autoencoder for rule_based prediction

parent ad1c6e9a
......@@ -1246,8 +1246,10 @@ class table_dict(dict):
X_test = deepof.utils.align_trajectories(X_test, align)
X_test = deepof.utils.rolling_window(X_test, window_size, window_step)
if self._propagate_labels:
y_test = y_test[::window_step][: X_test.shape[0]]
if self._propagate_labels or self._propagate_annotations:
y_test = deepof.utils.rolling_window(y_test, window_size, window_step)
y_test = y_test.mean(axis=1)
if align == "center":
X_test = deepof.utils.align_trajectories(X_test, align)
......
......@@ -267,10 +267,6 @@ window_step = args.window_step
if not train_path:
raise ValueError("Set a valid data path for the training to run")
if not val_num:
raise ValueError(
"Set a valid data path / validation number for the validation to run"
)
assert input_type in [
"coords",
......@@ -324,6 +320,9 @@ coords = project_coords.get_coords(
align=animal_id + undercond + "Spine_1",
align_inplace=True,
propagate_labels=(phenotype_prediction > 0),
propagate_annotations=(
False if not rule_based_prediction else project_coords.rule_based_annotation()
),
)
distances = project_coords.get_distances()
angles = project_coords.get_angles()
......@@ -363,7 +362,7 @@ X_train, y_train, X_val, y_val = batch_preprocess(input_dict_train[input_type])
print("Training set shape:", X_train.shape)
print("Validation set shape:", X_val.shape)
if phenotype_prediction > 0:
if any([phenotype_prediction, rule_based_prediction]):
print("Training set label shape:", y_train.shape)
print("Validation set label shape:", y_val.shape)
......@@ -388,7 +387,7 @@ if not tune:
output_path=output_path,
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
rule_based_prediction=rule_base_prediction,
rule_based_prediction=rule_based_prediction,
save_checkpoints=False,
save_weights=True,
variational=variational,
......
......@@ -355,6 +355,14 @@ def autoencoder_fitting(
metrics=metrics,
)
# Gets the number of rule-based features
try:
rule_based_features = (
y_train.shape[1] if not phenotype_prediction else y_train.shape[1] - 1
)
except IndexError:
rule_based_features = 0
# Build models
if not variational:
encoder, decoder, ae = deepof.models.SEQ_2_SEQ_AE(
......@@ -385,9 +393,7 @@ def autoencoder_fitting(
next_sequence_prediction=next_sequence_prediction,
phenotype_prediction=phenotype_prediction,
rule_based_prediction=rule_based_prediction,
rule_based_features=(
y_train.shape[1] if not phenotype_prediction else y_train.shape[1] - 1
),
rule_based_features=rule_based_features,
reg_cat_clusters=reg_cat_clusters,
reg_cluster_variance=reg_cluster_variance,
).build(
......@@ -443,16 +449,16 @@ def autoencoder_fitting(
Xvals, yvals = X_val[:-1], [X_val[:-1], X_val[1:]]
if phenotype_prediction > 0.0:
ys += [y_train[:, 0]]
yvals += [y_val[:, 0]]
ys += [y_train[-Xs.shape[0] :, 0]]
yvals += [y_val[-Xs.shape[0] :, 0]]
# Remove the used column (phenotype) from both y arrays
y_train = y_train[:, 1:]
y_val = y_val[:, 1:]
if rule_based_prediction > 0.0:
ys += [y_train]
yvals += [y_val]
ys += [y_train[-Xs.shape[0]:]]
yvals += [y_val[-Xs.shape[0]:]]
ae.fit(
x=Xs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment