diff --git a/src/re/lanczos.py b/src/re/lanczos.py index 06ad881dcff3b266666671e4f62d08bf92511f70..df78f4fb6aa421b16f821b79199c0a23da624437 100644 --- a/src/re/lanczos.py +++ b/src/re/lanczos.py @@ -26,6 +26,10 @@ def lanczos_tridiag( v = v / jnp.linalg.norm(v) vecs = vecs.at[0].set(v) + # TODO + # * use `forest_util.dot` in favor of plain `jnp.dot` + # * remove all reshapes as they are unnecessary + # Zeroth iteration w = mat(v) if w.shape != shape_dtype_struct.shape: