  • Remove jax from old nifty · 5c4943d4
    Philipp Arras authored
    The current jax integration in the old nifty causes a performance
    regression. If jax is not installed on the system, Python tries to
    import jax **many, many times** during an optimization run.
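
The mechanism can be reproduced outside of nifty. The following is a minimal sketch, not the nifty code itself: Python caches successful imports in `sys.modules`, but a failed import is not cached, so every further attempt searches for the module again. The module name, the `CountingFinder` helper, and both `available_*` functions are hypothetical and exist only for this illustration.

```python
import functools
import importlib
import importlib.abc
import sys

# Hypothetical name of a package that is not installed.
MISSING = "module_that_is_not_installed_xyz"


class CountingFinder(importlib.abc.MetaPathFinder):
    """Counts how often the import system searches for one module name."""

    def __init__(self, name):
        self.name = name
        self.lookups = 0

    def find_spec(self, fullname, path=None, target=None):
        if fullname == self.name:
            self.lookups += 1
        return None  # never claims the module; the regular finders run next


def available_uncached():
    """Attempt the import on every call (the slow pattern)."""
    try:
        importlib.import_module(MISSING)
        return True
    except ImportError:
        return False


@functools.lru_cache(maxsize=None)
def available_cached():
    """Attempt the import once and remember the outcome."""
    return available_uncached()


finder = CountingFinder(MISSING)
sys.meta_path.insert(0, finder)
try:
    for _ in range(1000):
        available_uncached()
    uncached_lookups = finder.lookups

    finder.lookups = 0
    for _ in range(1000):
        available_cached()
    cached_lookups = finder.lookups
finally:
    sys.meta_path.remove(finder)

print(uncached_lookups, cached_lookups)  # 1000 1
```

Each uncached attempt triggers a full module search, which for a missing package means scanning every entry on `sys.path`. Remembering the result of the first attempt would be one way to avoid the cost. The commit instead removes the interface altogether.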
    
    As an example: Profiling `demos/old_nifty/getting_started_3.py` for ~40
    seconds gives me:
    
    ```
             118722685 function calls (118252629 primitive calls) in 41.364 seconds
    
       Ordered by: internal time
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      2972947    8.361    0.000   26.735    0.000 <frozen importlib._bootstrap_external>:1597(find_spec)
      2997031    6.230    0.000    6.230    0.000 {built-in method posix.stat}
     14874427    5.900    0.000   10.026    0.000 <frozen importlib._bootstrap_external>:126(_path_join)
     29768056    2.474    0.000    2.474    0.000 {method 'rstrip' of 'str' objects}
    14912280/14912255    1.663    0.000    1.666    0.000 {method 'join' of 'str' objects}
        44175    1.403    0.000   28.693    0.001 <frozen importlib._bootstrap_external>:1495(_get_spec)
      137/125    1.202    0.009    1.219    0.010 {built-in method _imp.exec_dynamic}
     14859814    1.185    0.000    1.185    0.000 <frozen importlib._bootstrap>:491(_verbose_message)
         1417    0.567    0.000    0.567    0.000 {method 'dot' of 'numpy.ndarray' objects}
      2996914    0.446    0.000    6.676    0.000 <frozen importlib._bootstrap_external>:140(_path_stat)
      5919430    0.445    0.000    0.485    0.000 {built-in method builtins.isinstance}
       119292    0.392    0.000    0.559    0.000 _stride_tricks_impl.py:340(_broadcast_to)
         6362    0.390    0.000    0.390    0.000 {built-in method ducc0.fft.genuine_hartley}
       296579    0.384    0.000    0.970    0.000 field.py:50(__init__)
    ```
    
    This indicates that most of the run time is spent in the import
    machinery, i.e. trying to import jax: `_get_spec` alone accounts for
    28.7 s of cumulative time out of 41.4 s. After the patch we have:
    
    ```
             109668239 function calls (107820951 primitive calls) in 46.655 seconds
    
       Ordered by: internal time
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
       369822    2.190    0.000    4.989    0.000 chain_operator.py:43(simplify)
         4539    1.824    0.000    1.824    0.000 {method 'dot' of 'numpy.ndarray' objects}
     19048364    1.814    0.000    1.990    0.000 {built-in method builtins.isinstance}
       550077    1.693    0.000    2.439    0.000 _stride_tricks_impl.py:340(_broadcast_to)
        27484    1.691    0.000    1.691    0.000 {built-in method ducc0.fft.genuine_hartley}
      1339982    1.686    0.000    4.271    0.000 field.py:50(__init__)
       578749    1.347    0.000    4.067    0.000 field.py:689(_binary_op)
      168/144    1.223    0.007    1.249    0.009 {built-in method _imp.exec_dynamic}
       347476    1.022    0.000    1.981    0.000 diagonal_operator.py:141(apply)
       153638    0.817    0.000    0.817    0.000 {built-in method ducc0.misc.vdot}
    ```

    Both profiles cover a similar wall time, but the import machinery no
    longer appears among the top entries and the patched run gets more
    numerical work done: about 4.3 times as many calls to
    `ducc0.fft.genuine_hartley` (27484 vs. 6362).
    
    Since we have nifty8.re as a proper jax implementation of nifty, I
    suggest removing the legacy nifty8-to-jax interface.
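
Profiles in the format shown above can be collected with the standard-library `cProfile` and `pstats` modules. The sketch below is an assumption about the general approach, not a record of how the numbers in this message were obtained. `workload` is a hypothetical stand-in for the demo script.

```python
import cProfile
import io
import pstats


def workload():
    """Hypothetical stand-in for the code being profiled."""
    return sum(i * i for i in range(100_000))


profiler = cProfile.Profile()
profiler.enable()
result = workload()
profiler.disable()

# Sort by internal time (the "tottime" column) and keep the top 15 entries.
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream)
stats.sort_stats(pstats.SortKey.TIME).print_stats(15)
report = stream.getvalue()
print(report)
```

For a whole script, `python -m cProfile -s tottime <script>` prints the same kind of table without any changes to the code.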