diff --git a/src/re/lanczos.py b/src/re/lanczos.py
index 06ad881dcff3b266666671e4f62d08bf92511f70..df78f4fb6aa421b16f821b79199c0a23da624437 100644
--- a/src/re/lanczos.py
+++ b/src/re/lanczos.py
@@ -26,6 +26,10 @@ def lanczos_tridiag(
     v = v / jnp.linalg.norm(v)
     vecs = vecs.at[0].set(v)
 
+    # TODO
+    # * use `forest_util.dot` in favor of plain `jnp.dot`
+    # * remove all reshapes as they are unnecessary
+
     # Zeroth iteration
     w = mat(v)
     if w.shape != shape_dtype_struct.shape: