Philipp Arras authored
The current jax integration into the old nifty has a performance regression. If jax is not installed on the system, Python tries to import jax **many, many times** during an optimization run.

As an example, profiling `demos/old_nifty/getting_started_3.py` for ~40 seconds gives:

```
         118722685 function calls (118252629 primitive calls) in 41.364 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
  2972947    8.361    0.000   26.735    0.000 <frozen importlib._bootstrap_external>:1597(find_spec)
  2997031    6.230    0.000    6.230    0.000 {built-in method posix.stat}
 14874427    5.900    0.000   10.026    0.000 <frozen importlib._bootstrap_external>:126(_path_join)
 29768056    2.474    0.000    2.474    0.000 {method 'rstrip' of 'str' objects}
14912280/14912255    1.663    0.000    1.666    0.000 {method 'join' of 'str' objects}
    44175    1.403    0.000   28.693    0.001 <frozen importlib._bootstrap_external>:1495(_get_spec)
  137/125    1.202    0.009    1.219    0.010 {built-in method _imp.exec_dynamic}
 14859814    1.185    0.000    1.185    0.000 <frozen importlib._bootstrap>:491(_verbose_message)
     1417    0.567    0.000    0.567    0.000 {method 'dot' of 'numpy.ndarray' objects}
  2996914    0.446    0.000    6.676    0.000 <frozen importlib._bootstrap_external>:140(_path_stat)
  5919430    0.445    0.000    0.485    0.000 {built-in method builtins.isinstance}
   119292    0.392    0.000    0.559    0.000 _stride_tricks_impl.py:340(_broadcast_to)
     6362    0.390    0.000    0.390    0.000 {built-in method ducc0.fft.genuine_hartley}
   296579    0.384    0.000    0.970    0.000 field.py:50(__init__)
```

This indicates that a significant fraction of the run time is spent trying to import jax: `_get_spec` alone accounts for 28.7 s of the 41.4 s cumulatively.
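The mechanism behind this can be reproduced in isolation. Python does not record a failed import in `sys.modules`, so every further attempt to import a missing module walks the finders on `sys.meta_path` again. The following is a minimal sketch, not nifty code: `CountingFinder`, `try_import_optional` and `count_lookups` are hypothetical names introduced only for illustration.

```python
import importlib
import importlib.abc
import sys


class CountingFinder(importlib.abc.MetaPathFinder):
    """Hypothetical meta-path finder that only counts lookups for one name."""

    def __init__(self, name):
        self.name = name
        self.calls = 0

    def find_spec(self, fullname, path=None, target=None):
        if fullname == self.name:
            self.calls += 1
        return None  # never resolve anything; defer to the regular finders


def try_import_optional(name):
    """Hypothetical hot-path helper: import lazily, return None on failure."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None


def count_lookups(name, repetitions):
    """Count how often the import system searches for `name`."""
    finder = CountingFinder(name)
    sys.meta_path.insert(0, finder)  # placed first so it sees every lookup
    try:
        for _ in range(repetitions):
            try_import_optional(name)
    finally:
        sys.meta_path.remove(finder)
    return finder.calls
```

Calling `count_lookups` with a module name that is not installed triggers one full search per attempt, whereas an already imported module such as `sys` is served from `sys.modules` without any search.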
After the patch we have:

```
         109668239 function calls (107820951 primitive calls) in 46.655 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   369822    2.190    0.000    4.989    0.000 chain_operator.py:43(simplify)
     4539    1.824    0.000    1.824    0.000 {method 'dot' of 'numpy.ndarray' objects}
 19048364    1.814    0.000    1.990    0.000 {built-in method builtins.isinstance}
   550077    1.693    0.000    2.439    0.000 _stride_tricks_impl.py:340(_broadcast_to)
    27484    1.691    0.000    1.691    0.000 {built-in method ducc0.fft.genuine_hartley}
  1339982    1.686    0.000    4.271    0.000 field.py:50(__init__)
   578749    1.347    0.000    4.067    0.000 field.py:689(_binary_op)
  168/144    1.223    0.007    1.249    0.009 {built-in method _imp.exec_dynamic}
   347476    1.022    0.000    1.981    0.000 diagonal_operator.py:141(apply)
   153638    0.817    0.000    0.817    0.000 {built-in method ducc0.misc.vdot}
```

The importlib entries no longer appear among the listed top entries, and in a comparable wall-clock time the run performs roughly four times as many `genuine_hartley` calls (27484 vs. 6362).

Since we have nifty8.re as a proper jax implementation of nifty, I suggest removing the legacy nifty8-to-jax interface.
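To compare runs like the ones above without reading the full table, the time spent inside the import machinery can be summed directly from the profile. This is a rough sketch using the standard `cProfile` and `pstats` modules; `import_machinery_time` is a hypothetical helper name, and matching on `"importlib._bootstrap"` assumes the frozen-module filenames look like those in the profiles shown here.

```python
import cProfile
import pstats


def import_machinery_time(func):
    """Profile func() and return (seconds in importlib frames, total seconds)."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        func()
    finally:
        profiler.disable()

    stats = pstats.Stats(profiler)
    # stats.stats maps (filename, lineno, funcname) to
    # (primitive calls, total calls, tottime, cumtime, callers).
    in_importlib = sum(
        tottime
        for (filename, _lineno, _funcname), (_cc, _nc, tottime, _ct, _callers)
        in stats.stats.items()
        if "importlib._bootstrap" in filename
    )
    return in_importlib, stats.total_tt
```

Only `tottime` is summed, so nested calls are not counted twice. Built-in helpers such as `posix.stat` are not attributed to importlib by this filter, which means the result is a lower bound on the import overhead.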