diff --git a/source/utils.py b/source/utils.py index 59c1e743ec49195ea2d4539fb1b3f88664d328f9..1d02602e27ab213aeb8802ec15b2fd6aa20b56a7 100644 --- a/source/utils.py +++ b/source/utils.py @@ -225,12 +225,12 @@ def smooth_boolean_array(a: np.array) -> np.array: def rolling_window(a: np.array, window_size: int, window_step: int) -> np.array: """Returns a 3D numpy.array with a sliding-window extra dimension - Parameters: - - a (2D np.array): N (instances) * m (features) shape + Parameters: + - a (2D np.array): N (instances) * m (features) shape - Returns: - - rolled_a (3D np.array): - N (sliding window instances) * l (sliding window size) * m (features)""" + Returns: + - rolled_a (3D np.array): + N (sliding window instances) * l (sliding window size) * m (features)""" shape = (a.shape[0] - window_size + 1, window_size) + a.shape[1:] strides = (a.strides[0],) + a.strides @@ -241,14 +241,24 @@ def rolling_window(a: np.array, window_size: int, window_step: int) -> np.array: @jit -def smooth_mult_trajectory(series, alpha=0.15): - """smooths a trajectory using exponentially weighted averages""" +def smooth_mult_trajectory(series: np.array, alpha: float = 0.15) -> np.array: + """Returns a smooths a trajectory using exponentially weighted averages + + Parameters: + - series (np.array): 2D trajectory array with N (instances) * m (features) + - alpha (float): 0 <= alpha <= 1; indicates the weight assigned to the current observation. + higher (alpha~1) indicates less smoothing; lower indicates more (alpha~0) + + Returns: + - smoothed_series (np.array): smoothed version of the input, with equal shape""" result = [series[0]] for n in range(len(series)): result.append(alpha * series[n] + (1 - alpha) * result[n - 1]) - return np.array(result) + smoothed_series = np.array(result) + + return smoothed_series ##### IMAGE/VIDEO PROCESSING FUNCTIONS ##### diff --git a/test_deepof/test_utils.py b/test_deepof/test_utils.py index cac6997a119dd6a4f602f3dd2ed0f932fd0152d1..69459ccba27f5ca9424fe02ef23dadddefdfa8fa 100644 --- a/test_deepof/test_utils.py +++ b/test_deepof/test_utils.py @@ -220,5 +220,5 @@ def test_rolling_window(a, window): rolled_shape = rolling_window(a, window_size, window_step).shape - assert len(rolled_shape) == a.shape + 1 + assert len(rolled_shape) == len(a.shape) + 1 assert rolled_shape[1] == window_size