Sunset nifty2jax (and jax_expr)
Currently, the current jax integration into the old nifty features a performance regression. If no jax is installed on the system, python tries to import jax many many times during an optimization run.
As an example: Profiling demos/old_nifty/getting_started_3.py
for ~40
seconds gives me:
118722685 function calls (118252629 primitive calls) in 41.364 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
2972947 8.361 0.000 26.735 0.000 <frozen importlib._bootstrap_external>:1597(find_spec)
2997031 6.230 0.000 6.230 0.000 {built-in method posix.stat}
14874427 5.900 0.000 10.026 0.000 <frozen importlib._bootstrap_external>:126(_path_join)
29768056 2.474 0.000 2.474 0.000 {method 'rstrip' of 'str' objects}
14912280/14912255 1.663 0.000 1.666 0.000 {method 'join' of 'str' objects}
44175 1.403 0.000 28.693 0.001 <frozen importlib._bootstrap_external>:1495(_get_spec)
137/125 1.202 0.009 1.219 0.010 {built-in method _imp.exec_dynamic}
14859814 1.185 0.000 1.185 0.000 <frozen importlib._bootstrap>:491(_verbose_message)
1417 0.567 0.000 0.567 0.000 {method 'dot' of 'numpy.ndarray' objects}
2996914 0.446 0.000 6.676 0.000 <frozen importlib._bootstrap_external>:140(_path_stat)
5919430 0.445 0.000 0.485 0.000 {built-in method builtins.isinstance}
119292 0.392 0.000 0.559 0.000 _stride_tricks_impl.py:340(_broadcast_to)
6362 0.390 0.000 0.390 0.000 {built-in method ducc0.fft.genuine_hartley}
296579 0.384 0.000 0.970 0.000 field.py:50(__init__)
This indicates that a significant fraction of the time is spent trying to import jax. After the patch we have:
109668239 function calls (107820951 primitive calls) in 46.655 seconds
Ordered by: internal time
ncalls tottime percall cumtime percall filename:lineno(function)
369822 2.190 0.000 4.989 0.000 chain_operator.py:43(simplify)
4539 1.824 0.000 1.824 0.000 {method 'dot' of 'numpy.ndarray' objects}
19048364 1.814 0.000 1.990 0.000 {built-in method builtins.isinstance}
550077 1.693 0.000 2.439 0.000 _stride_tricks_impl.py:340(_broadcast_to)
27484 1.691 0.000 1.691 0.000 {built-in method ducc0.fft.genuine_hartley}
1339982 1.686 0.000 4.271 0.000 field.py:50(__init__)
578749 1.347 0.000 4.067 0.000 field.py:689(_binary_op)
168/144 1.223 0.007 1.249 0.009 {built-in method _imp.exec_dynamic}
347476 1.022 0.000 1.981 0.000 diagonal_operator.py:141(apply)
153638 0.817 0.000 0.817 0.000 {built-in method ducc0.misc.vdot}
Since we have nifty8.re as a proper jax implementation of nifty, I suggest to remove the legacy nifty8-to-jax interface.