Commit ff1ec1ed authored by lucas_miranda's avatar lucas_miranda
Browse files

added inplace alignment on deepof.data.coordinates.get_coords() on data.py

parent c2d41b0f
Pipeline #83429 passed with stage
in 14 minutes and 20 seconds
......@@ -56,7 +56,7 @@ class project:
path: str = deepof.utils.os.path.join("."),
exp_conditions: dict = None,
arena: str = "circular",
smooth_alpha: float = 1.0,
smooth_alpha: float = 0.99,
arena_dims: tuple = (1,),
model: str = "mouse_topview",
animal_ids: List = tuple([""]),
......@@ -84,7 +84,7 @@ class project:
self.video_format = video_format
self.arena = arena
self.arena_dims = arena_dims
self.smooth_alpha = 0.99
self.smooth_alpha = smooth_alpha
self.scales = self.get_scale
self.animal_ids = animal_ids
......@@ -421,6 +421,7 @@ class coordinates:
speed: int = 0,
length: str = None,
align: bool = False,
align_inplace: bool = False,
) -> Table_dict:
"""
Returns a table_dict object with the coordinates of each animal as values.
......@@ -436,6 +437,8 @@ class coordinates:
of the stored dataframes will reflect the actual timing in the video.
- align (bool): selects the body part to which later processes will align the frames with
(see preprocess in table_dict documentation).
- align_inplace (bool): Only valid if align is set. Aligns the vector that goes from the origin to
the selected body part with the y axis, for all time points.
Returns:
tab_dict (Table_dict): table_dict object containing all the computed information
......@@ -521,7 +524,16 @@ class coordinates:
(align, ("phi" if polar else "x")),
(align, ("rho" if polar else "y")),
] + columns
tabs[key] = tab[columns]
tab = tab[columns]
tabs[key] = tab
if align_inplace and polar is False:
index = tab.columns
tab = pd.DataFrame(
deepof.utils.align_trajectories(np.array(tab), mode="all")
)
tab.columns = index
tabs[key] = tab
return table_dict(
tabs,
......@@ -739,7 +751,13 @@ class table_dict(dict):
# noinspection PyTypeChecker
def plot_heatmaps(
self, bodyparts: list, save: bool = False, i: int = 0
self,
bodyparts: list,
xlim: float = None,
ylim: float = None,
save: bool = False,
i: int = 0,
dpi: int = 100,
) -> plt.figure:
"""Plots heatmaps of the specified body parts (bodyparts) of the specified animal (i)"""
......@@ -753,19 +771,14 @@ class table_dict(dict):
warnings.warn("Heatmaps look better if you center the data")
if self._arena == "circular":
x_lim = (
[-self._arena_dims[i][2] / 2, self._arena_dims[i][2] / 2]
if self._center
else [0, self._arena_dims[i][0]]
)
y_lim = (
[-self._arena_dims[i][2] / 2, self._arena_dims[i][2] / 2]
if self._center
else [0, self._arena_dims[i][1]]
)
heatmaps = deepof.visuals.plot_heatmap(
list(self.values())[i], bodyparts, xlim=x_lim, ylim=y_lim, save=save,
list(self.values())[i],
bodyparts,
xlim=xlim,
ylim=ylim,
save=save,
dpi=dpi,
)
return heatmaps
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment