Commit be362675 authored by lucas_miranda

Added video generation options to rule_based_annotation

parent 76333c27
Pipeline #83264 passed with stage in 18 minutes and 7 seconds
@@ -637,7 +637,9 @@ class coordinates:
return self._arena, self._arena_dims, self._scales
# noinspection PyDefaultArgument
def rule_based_annotation(self, hparams: Dict = {}) -> Table_dict:
def rule_based_annotation(
self, hparams: Dict = {}, video_output: bool = False, frame_limit: int = np.inf
) -> Table_dict:
"""Annotates coordinates using a simple rule-based pipeline"""
tag_dict = {}
@@ -648,9 +650,33 @@ class coordinates:
self,
idx,
recog_limit=1,
path=os.path.join(self._path,"Videos"),
path=os.path.join(self._path, "Videos"),
hparams=hparams,
)
if video_output: # pragma: no cover
if type(video_output) == list:
vid_idxs = video_output
elif video_output == "all":
vid_idxs = list(self._tables.keys())
else:
raise AttributeError(
"Video output must be either 'all' or a list with the names of the videos to render"
)
for idx in vid_idxs:
deepof.pose_utils.rule_based_video(
self,
list(self._tables.keys()),
self._videos,
list(self._tables.keys()).index(idx),
tag_dict[idx],
frame_limit=frame_limit,
recog_limit=1,
path=os.path.join(self._path, "Videos"),
hparams=hparams,
)
return table_dict(
tag_dict, typ="rule-based", arena=self._arena, arena_dims=self._arena_dims
)
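For context, a minimal usage sketch of the extended signature, assuming a project has been loaded the way the new test further down does (the experiment names passed to video_output are hypothetical):

import os
import deepof.data

proj = deepof.data.project(
    path=os.path.join(".", "tests", "test_examples"),
    arena="circular",
    arena_dims=tuple([380]),
    video_format=".mp4",
    table_format=".h5",
).run()

# Annotate all experiments; render tagged videos only for the listed ones,
# capping each rendered video at 500 frames. video_output="all" renders every video.
tags = proj.rule_based_annotation(
    video_output=["Test 1", "Test 2"],  # hypothetical experiment keys
    frame_limit=500,
)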
@@ -666,7 +692,7 @@ class table_dict(dict):
def __init__(
self,
tabs: Coordinates,
tabs: Dict,
typ: str,
arena: str = None,
arena_dims: np.array = None,
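The second hunk only tightens the type hint on table_dict's first argument: the class wraps a plain Dict of per-experiment tables, not a Coordinates object. As a generic sketch of that pattern (not deepof's full class; only _type is confirmed by the new test, the other attribute names are illustrative), a dict subclass can carry dataset-level metadata next to its entries:

import numpy as np
from typing import Dict

class MetaDict(dict):
    """dict subclass keeping dataset-level metadata next to its per-experiment entries."""

    def __init__(self, tabs: Dict, typ: str, arena: str = None, arena_dims: np.array = None):
        super().__init__(tabs)
        self._type = typ              # e.g. "rule-based", as asserted in the new test
        self._arena = arena           # illustrative attribute names
        self._arena_dims = arena_dims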
@@ -96,14 +96,22 @@ class SEQ_2_SEQ_AE(HyperModel):
ENCODING,
activation="relu",
kernel_constraint=UnitNorm(axis=1),
activity_regularizer=deepof.model_utils.uncorrelated_features_constraint(3, weightage=1.0),
activity_regularizer=deepof.model_utils.uncorrelated_features_constraint(
3, weightage=1.0
),
kernel_initializer=Orthogonal(),
)
# Decoder layers
Model_D0 = deepof.model_utils.DenseTranspose(Model_E5, activation="relu", output_dim=ENCODING, )
Model_D1 = deepof.model_utils.DenseTranspose(Model_E4, activation="relu", output_dim=DENSE_2, )
Model_D2 = deepof.model_utils.DenseTranspose(Model_E3, activation="relu", output_dim=DENSE_1, )
Model_D0 = deepof.model_utils.DenseTranspose(
Model_E5, activation="relu", output_dim=ENCODING,
)
Model_D1 = deepof.model_utils.DenseTranspose(
Model_E4, activation="relu", output_dim=DENSE_2,
)
Model_D2 = deepof.model_utils.DenseTranspose(
Model_E3, activation="relu", output_dim=DENSE_1,
)
Model_D3 = RepeatVector(self.input_shape[1])
Model_D4 = Bidirectional(
LSTM(
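The decoder reuses the encoder's Dense layers through deepof.model_utils.DenseTranspose, whose implementation is not part of this diff. Below is a generic tied-weights sketch of that idea (a layer that multiplies by the transposed kernel of a reference Dense layer), written against plain Keras and not taken from deepof:

import tensorflow as tf
from tensorflow.keras.layers import Dense, Layer

class TiedDenseTranspose(Layer):
    """Dense layer whose kernel is the transpose of a reference Dense layer's kernel."""

    def __init__(self, dense: Dense, output_dim: int, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.dense = dense                       # reference (encoder) layer; must be built first
        self.output_dim = output_dim
        self.activation = tf.keras.activations.get(activation)

    def build(self, input_shape):
        # Only the bias is a new variable; the kernel is borrowed from self.dense.
        self.bias = self.add_weight(name="bias", shape=(self.output_dim,), initializer="zeros")
        super().build(input_shape)

    def call(self, inputs):
        # (batch, units) x (output_dim, units)^T -> (batch, output_dim)
        return self.activation(tf.matmul(inputs, self.dense.kernel, transpose_b=True) + self.bias)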
@@ -319,7 +327,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
z_cat = Dense(self.number_of_components, activation="softmax",)(encoder)
z_gauss = Dense(
deepof.model_utils.tfpl.IndependentNormal.params_size(ENCODING * self.number_of_components),
deepof.model_utils.tfpl.IndependentNormal.params_size(
ENCODING * self.number_of_components
),
activation=None,
)(encoder)
@@ -339,7 +349,7 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
z_gauss = Reshape([2 * ENCODING, self.number_of_components])(z_gauss)
z = deepof.model_utils.tfpl.DistributionLambda(
lambda gauss: deepof.model_utils.tfd.mixture.Mixture(
cat=deepof.model_utils.tfd.categorical.Categorical(probs=gauss[0], ),
cat=deepof.model_utils.tfd.categorical.Categorical(probs=gauss[0],),
components=[
deepof.model_utils.tfd.Independent(
deepof.model_utils.tfd.Normal(
@@ -351,7 +361,9 @@ class SEQ_2_SEQ_GMVAE(HyperModel):
for k in range(self.number_of_components)
],
),
activity_regularizer=deepof.model_utils.uncorrelated_features_constraint(3, weightage=1.0),
activity_regularizer=deepof.model_utils.uncorrelated_features_constraint(
3, weightage=1.0
),
)([z_cat, z_gauss])
if "ELBO" in self.loss:
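Both GMVAE variants parameterise the latent posterior with two Dense heads (mixture weights and stacked per-component Normal parameters) and wrap them in a mixture distribution. A condensed, self-contained sketch of that construction using tensorflow-probability directly (deepof re-exports these modules as model_utils.tfd / tfpl; the sizes below are illustrative):

import tensorflow as tf
import tensorflow_probability as tfp

tfd, tfpl = tfp.distributions, tfp.layers

ENCODING, K = 6, 3                          # latent size and number of components (illustrative)
encoder = tf.keras.Input(shape=(64,))       # stand-in for the upstream encoder output

# Two heads: categorical mixture weights and per-component Normal parameters.
z_cat = tf.keras.layers.Dense(K, activation="softmax")(encoder)
z_gauss = tf.keras.layers.Dense(
    tfpl.IndependentNormal.params_size(ENCODING * K), activation=None
)(encoder)
z_gauss = tf.keras.layers.Reshape([2 * ENCODING, K])(z_gauss)

# Mixture of K independent Gaussians; softplus keeps the scale parameters positive.
z = tfpl.DistributionLambda(
    lambda gauss: tfd.Mixture(
        cat=tfd.Categorical(probs=gauss[0]),
        components=[
            tfd.Independent(
                tfd.Normal(
                    loc=gauss[1][..., :ENCODING, k],
                    scale=tf.nn.softplus(gauss[1][..., ENCODING:, k]),
                ),
                reinterpreted_batch_ndims=1,
            )
            for k in range(K)
        ],
    )
)([z_cat, z_gauss])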
@@ -120,14 +120,22 @@ class SEQ_2_SEQ_AE:
self.ENCODING,
activation="elu",
kernel_constraint=UnitNorm(axis=1),
activity_regularizer=deepof.model_utils.uncorrelated_features_constraint(2, weightage=1.0),
activity_regularizer=deepof.model_utils.uncorrelated_features_constraint(
2, weightage=1.0
),
kernel_initializer=Orthogonal(),
)
# Decoder layers
Model_D0 = deepof.model_utils.DenseTranspose(Model_E5, activation="elu", output_dim=self.ENCODING, )
Model_D1 = deepof.model_utils.DenseTranspose(Model_E4, activation="elu", output_dim=self.DENSE_2, )
Model_D2 = deepof.model_utils.DenseTranspose(Model_E3, activation="elu", output_dim=self.DENSE_1, )
Model_D0 = deepof.model_utils.DenseTranspose(
Model_E5, activation="elu", output_dim=self.ENCODING,
)
Model_D1 = deepof.model_utils.DenseTranspose(
Model_E4, activation="elu", output_dim=self.DENSE_2,
)
Model_D2 = deepof.model_utils.DenseTranspose(
Model_E3, activation="elu", output_dim=self.DENSE_1,
)
Model_D3 = RepeatVector(input_shape[1])
Model_D4 = Bidirectional(
LSTM(
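The deepof.model_utils.uncorrelated_features_constraint used as an activity regularizer above (taking the encoding dimensionality and a weightage) is not defined in this diff. As a rough sketch of the underlying idea only, an activity regularizer can penalise the off-diagonal covariance of the encoded features; the class below is illustrative, not deepof's implementation:

import tensorflow as tf

class UncorrelatedFeatures(tf.keras.regularizers.Regularizer):
    """Penalise correlations between encoded features within a batch (sketch)."""

    def __init__(self, weightage: float = 1.0):
        self.weightage = weightage

    def __call__(self, x):
        # Batch covariance of the activations, shape (features, features).
        x_centered = x - tf.reduce_mean(x, axis=0)
        cov = tf.matmul(x_centered, x_centered, transpose_a=True) / tf.cast(tf.shape(x)[0], x.dtype)
        # Keep the diagonal (each feature's own variance); penalise everything else.
        off_diag = cov - tf.linalg.diag(tf.linalg.diag_part(cov))
        return self.weightage * tf.reduce_sum(tf.square(off_diag))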
@@ -298,7 +306,7 @@ class SEQ_2_SEQ_GMVAE:
),
components=[
deepof.model_utils.tfd.Independent(
deepof.model_utils.tfd.Normal(loc=init_means[k], scale=1, ),
deepof.model_utils.tfd.Normal(loc=init_means[k], scale=1,),
reinterpreted_batch_ndims=1,
)
for k in range(self.number_of_components)
@@ -527,12 +535,12 @@ class SEQ_2_SEQ_GMVAE:
z = deepof.model_utils.tfpl.DistributionLambda(
lambda gauss: deepof.model_utils.tfd.mixture.Mixture(
cat=deepof.model_utils.tfd.categorical.Categorical(probs=gauss[0], ),
cat=deepof.model_utils.tfd.categorical.Categorical(probs=gauss[0],),
components=[
deepof.model_utils.tfd.Independent(
deepof.model_utils.tfd.Normal(
loc=gauss[1][..., : self.ENCODING, k],
scale=softplus(gauss[1][..., self.ENCODING:, k]),
scale=softplus(gauss[1][..., self.ENCODING :, k]),
),
reinterpreted_batch_ndims=1,
)
@@ -694,7 +694,10 @@ def rule_based_video(
animal_ids = coordinates._animal_ids
undercond = "_" if len(animal_ids) > 1 else ""
vid_name = re.findall("(.*?)_", tracks[vid_index])[0]
try:
vid_name = re.findall("(.*?)_", tracks[vid_index])[0]
except IndexError:
vid_name = tracks[vid_index]
coords = coordinates.get_coords()[vid_name]
speeds = coordinates.get_coords(speed=1)[vid_name]
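The new try/except covers experiment keys that contain no underscore: re.findall then returns an empty list and the [0] index raises IndexError. A quick illustration with hypothetical names:

import re

re.findall("(.*?)_", "Test 1_raw")   # ["Test 1"]  -> [0] works
re.findall("(.*?)_", "Test1")        # []          -> [0] would raise IndexError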
@@ -754,7 +757,7 @@ def rule_based_video(
# Define the FPS. Also frame size is passed.
writer = cv2.VideoWriter()
writer.open(
re.findall("(.*?)_", tracks[vid_index])[0] + "_tagged.avi",
vid_name + "_tagged.avi",
cv2.VideoWriter_fourcc(*"MJPG"),
hparams["fps"],
(frame.shape[1], frame.shape[0]),
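The writer now derives its output name from vid_name instead of re-running the regular expression. For context, a minimal OpenCV writing loop of the kind rule_based_video sets up, with made-up frame size, fps and output name:

import cv2
import numpy as np

fps, width, height = 25, 640, 480              # illustrative values; deepof takes fps from hparams
writer = cv2.VideoWriter()
writer.open(
    "Test 1_tagged.avi",                       # hypothetical vid_name + "_tagged.avi"
    cv2.VideoWriter_fourcc(*"MJPG"),
    fps,
    (width, height),
    True,
)
for _ in range(100):                           # stand-in for the per-frame tagging loop
    frame = np.zeros((height, width, 3), dtype=np.uint8)
    writer.write(frame)
writer.release()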
@@ -161,6 +161,22 @@ def test_run(nodes, ego):
assert type(prun) == deepof.data.coordinates
def test_get_rule_based_annotation():
prun = deepof.data.project(
path=os.path.join(".", "tests", "test_examples"),
arena="circular",
arena_dims=tuple([380]),
video_format=".mp4",
table_format=".h5",
).run()
prun = prun.rule_based_annotation()
assert type(prun) == deepof.data.table_dict
assert prun._type == "rule-based"
@settings(deadline=None)
@given(
nodes=st.integers(min_value=0, max_value=1),