Skip to content

Sunset nifty2jax (and jax_expr)

Philipp Arras requested to merge g-philipp/nifty:old_nifty_cleanup into NIFTy_8

Currently, the current jax integration into the old nifty features a performance regression. If no jax is installed on the system, python tries to import jax many many times during an optimization run.

As an example: Profiling demos/old_nifty/getting_started_3.py for ~40 seconds gives me:

         118722685 function calls (118252629 primitive calls) in 41.364 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  2972947    8.361    0.000   26.735    0.000 <frozen importlib._bootstrap_external>:1597(find_spec)
  2997031    6.230    0.000    6.230    0.000 {built-in method posix.stat}
 14874427    5.900    0.000   10.026    0.000 <frozen importlib._bootstrap_external>:126(_path_join)
 29768056    2.474    0.000    2.474    0.000 {method 'rstrip' of 'str' objects}
14912280/14912255    1.663    0.000    1.666    0.000 {method 'join' of 'str' objects}
    44175    1.403    0.000   28.693    0.001 <frozen importlib._bootstrap_external>:1495(_get_spec)
  137/125    1.202    0.009    1.219    0.010 {built-in method _imp.exec_dynamic}
 14859814    1.185    0.000    1.185    0.000 <frozen importlib._bootstrap>:491(_verbose_message)
     1417    0.567    0.000    0.567    0.000 {method 'dot' of 'numpy.ndarray' objects}
  2996914    0.446    0.000    6.676    0.000 <frozen importlib._bootstrap_external>:140(_path_stat)
  5919430    0.445    0.000    0.485    0.000 {built-in method builtins.isinstance}
   119292    0.392    0.000    0.559    0.000 _stride_tricks_impl.py:340(_broadcast_to)
     6362    0.390    0.000    0.390    0.000 {built-in method ducc0.fft.genuine_hartley}
   296579    0.384    0.000    0.970    0.000 field.py:50(__init__)

This indicates that a significant fraction of the time is spent trying to import jax. After the patch we have:

         109668239 function calls (107820951 primitive calls) in 46.655 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   369822    2.190    0.000    4.989    0.000 chain_operator.py:43(simplify)
     4539    1.824    0.000    1.824    0.000 {method 'dot' of 'numpy.ndarray' objects}
 19048364    1.814    0.000    1.990    0.000 {built-in method builtins.isinstance}
   550077    1.693    0.000    2.439    0.000 _stride_tricks_impl.py:340(_broadcast_to)
    27484    1.691    0.000    1.691    0.000 {built-in method ducc0.fft.genuine_hartley}
  1339982    1.686    0.000    4.271    0.000 field.py:50(__init__)
   578749    1.347    0.000    4.067    0.000 field.py:689(_binary_op)
  168/144    1.223    0.007    1.249    0.009 {built-in method _imp.exec_dynamic}
   347476    1.022    0.000    1.981    0.000 diagonal_operator.py:141(apply)
   153638    0.817    0.000    0.817    0.000 {built-in method ducc0.misc.vdot}

Since we have nifty8.re as a proper jax implementation of nifty, I suggest to remove the legacy nifty8-to-jax interface.

CC: @pfrank @gedenhof @mtr @veberle @ensslint

Merge request reports

Loading