Wrong result of FFT.jax_expr
Applying FFT.jax_expr in the inverse direction seems to give wrong results: the output does not match FFT.inverse. Here is a snippet to reproduce the error:
import numpy as np
import nifty8 as ift

# Set up a 2D regular grid and the FFT operator acting on it
sp = ift.RGSpace((100, 100), distances=0.1)
FFT = ift.FFTOperator(sp)
inp = ift.from_random(FFT.domain)

# Forward direction: NIFTy and JAX agree
res_nifty = FFT(inp)
res_jax = FFT.jax_expr(inp.val)
assert np.allclose(res_nifty.val, res_jax)  # works
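# Inverse direction, variant 1: jax_expr of FFT with inverse=True
res_nifty_inv = FFT.inverse(res_nifty)
res_jax_inv = FFT.jax_expr(inp.val, inverse=True)
assert np.allclose(res_nifty_inv.val, res_jax_inv)  # fails

# Inverse direction, variant 2: jax_expr of FFT.inverse with inverse=True
res_nifty_inv = FFT.inverse(res_nifty)
res_jax_inv = FFT.inverse.jax_expr(inp.val, inverse=True)
assert np.allclose(res_nifty_inv.val, res_jax_inv)  # fails

# The NIFTy inverse recovers the original input
assert np.allclose(res_nifty_inv.val, inp.val)  # works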
assert np.allclose(res_jax_inv, inp.val)  # fails
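
For comparison, here is a minimal NumPy-only sketch of the round-trip property that the checks above rely on. It is an assumption-laden stand-in, not NIFTy's or JAX's implementation: it uses np.fft.fftn and np.fft.ifftn directly, and NIFTy's conventions (normalization, volume factors) may differ. All variable names are illustrative. It shows that the inverse only recovers the input when it is applied to the forward output, so it may be worth checking which array is passed to the inverse jax_expr call.

```python
import numpy as np

# Random real test field on a 100 x 100 grid (illustrative stand-in for inp.val)
rng = np.random.default_rng(0)
x = rng.standard_normal((100, 100))

# Forward transform, then inverse applied to the forward OUTPUT
fwd = np.fft.fftn(x)
roundtrip = np.fft.ifftn(fwd)

# Inverse applied to the original INPUT instead of the forward output
inv_of_input = np.fft.ifftn(x)

# The round trip recovers x (up to floating-point error) ...
roundtrip_ok = np.allclose(roundtrip, x)
# ... while the inverse of the untransformed input does not
mismatch = not np.allclose(inv_of_input, x)

print("round trip recovers input:", roundtrip_ok)
print("inverse of raw input differs from input:", mismatch)
```

If the inverse jax_expr still disagrees with FFT.inverse when both receive the forward output, the discrepancy would point at the JAX expression itself rather than at the test inputs.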