diff --git a/.envrc b/.envrc new file mode 100644 index 0000000000000000000000000000000000000000..3550a30f2de389e537ee40ca5e64a77dc185c79b --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/.gitignore b/.gitignore index 9b186399aed4de38693e315989de279f06df4243..4d73ff0428c0114389e7e365854a1315316fb2ea 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +# nix-related +result +.nix-venv + docs/source/user/getting_started_0.rst docs/source/user/custom_nonlinearities.rst docs/source/user/getting_started_4_CorrelatedFields.rst diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 6a0b4f63c508d92a604763e8e3bc9f62c5948bba..670e57a94c99e034d3f97a0017c994b53951c667 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -46,7 +46,7 @@ build_docker_from_cache: test_serial: stage: test script: - - pytest-3 -q --cov=nifty8 test + - pytest -q --cov=nifty8 test - > python3 -m coverage report --omit "*plot*" | tee coverage.txt - > @@ -57,7 +57,7 @@ test_mpi: variables: OMPI_MCA_btl_vader_single_copy_mechanism: none script: - - mpiexec -n 2 --bind-to none pytest-3 -q test/test_mpi + - mpiexec -n 2 --bind-to none pytest -q test/test_mpi pages: stage: release diff --git a/Dockerfile b/Dockerfile index fa5d96a46927a9f516f4c9b818b33ed5f295d3de..41d6389b6ef9457b90833bf2d952aa1f0bf34338 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,17 +3,20 @@ FROM debian:stable-slim RUN apt-get update && apt-get install -y \ # Needed for setup git python3-pip \ - # Packages needed for NIFTy - python3-scipy \ # Documentation build dependencies dvipng texlive-latex-base texlive-latex-extra \ + # Dependency of mpi4py + libopenmpi-dev \ + && rm -rf /var/lib/apt/lists/* +RUN DUCC0_OPTIMIZATION=portable pip3 install \ + # Packages needed for NIFTy + scipy \ + # Optional nifty dependencies + matplotlib h5py astropy ducc0 jax jaxlib mpi4py \ # Testing dependencies - python3-pytest-cov jupyter \ - # Optional NIFTy dependencies - python3-mpi4py python3-matplotlib python3-h5py \ - # more optional NIFTy dependencies - && DUCC0_OPTIMIZATION=portable pip3 install astropy ducc0 jupyter jax jaxlib sphinx pydata-sphinx-theme jupytext \ - && rm -rf /var/lib/apt/lists/* + pytest pytest-cov \ + # Documentation build dependencies + jupyter nbconvert jupytext sphinx pydata-sphinx-theme # Set matplotlib backend ENV MPLBACKEND agg diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000000000000000000000000000000000000..c025e56baf9f47810075222f2dbe8f193314a182 --- /dev/null +++ b/flake.lock @@ -0,0 +1,42 @@ +{ + "nodes": { + "flake-utils": { + "locked": { + "lastModified": 1659877975, + "narHash": "sha256-zllb8aq3YO3h8B/U0/J1WBgAL8EX5yWf5pMj3G0NAmc=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "c0e246b9b83f637f4681389ecabcb2681b4f3af0", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1663688007, + "narHash": "sha256-Ei7MJAYHTGl+reg0FpAhYddtZVVNLH3FOZrMSTTznaU=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "778d8ad2d838ea605b6abcb75fa0e39331f6f60c", + "type": "github" + }, + "original": { + "owner": "nixos", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000000000000000000000000000000000000..ceb87184e4a9ae761e677e712b5ae8b8494f7c67 --- /dev/null +++ b/flake.nix @@ -0,0 +1,41 @@ +{ + description = "Numerical Information Field Theory"; + + inputs = { + nixpkgs.url = "github:nixos/nixpkgs"; + flake-utils.url = "github:numtide/flake-utils"; + }; + outputs = { self, nixpkgs, flake-utils }: + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { inherit system; }; + myPyPkgs = pkgs.python3Packages; + + req.minimal = with myPyPkgs; [ numpy scipy ducc0 ]; + req.dev = with myPyPkgs; [ pytest pytest-cov matplotlib ]; + req.mpi = [ myPyPkgs.mpi4py pkgs.openmpi pkgs.openssh ]; + req.jax = with myPyPkgs; [ jax jaxlib ]; + req.rest = with myPyPkgs; [ astropy ]; + + req.docs = with myPyPkgs; [ sphinx jupyter jupytext ]; + # TODO add pydata-sphinx-theme + in { + packages.default = myPyPkgs.buildPythonPackage { + pname = "nifty8"; + version = "8.0"; # TODO Set this automatically + src = ./.; + nativeBuildInputs = req.minimal; + checkInputs = [ myPyPkgs.pytestCheckHook ] ++ req.dev ++ req.mpi; + pytestFlagsArray = [ "test" ]; + pythonImportsCheck = [ "nifty8" ]; + }; + + # TODO Add version with MPI, jax and both + + devShells.default = pkgs.mkShell { + nativeBuildInputs = with myPyPkgs; [ pip venvShellHook ] ++ req.minimal ++ req.dev; + # ( pkgs.lib.attrValues req ) ; + venvDir = "./.nix-nifty-venv"; + }; + }); +} diff --git a/src/library/nft.py b/src/library/nft.py index 0dd7932c2a92dd570bfc4c5bdc2672517b31a127..729a422634add322d7ea5a004ff09eb7b8f76c9d 100644 --- a/src/library/nft.py +++ b/src/library/nft.py @@ -90,6 +90,10 @@ class Nufft(LinearOperator): Requested precision, defaults to 2e-10. """ def __init__(self, target, pos, eps=2e-10): + try: + from ducc0.nufft import nu2u, u2nu + except ImportError: + raise ImportError("ducc0 needs to be installed for nifty.Nufft()") self._capability = self.TIMES | self.ADJOINT_TIMES self._target = makeDomain(target) if not isinstance(self._target[0], RGSpace): diff --git a/test/test_re/test_energies.py b/test/test_re/test_energies.py index 9df3cd1b5735164518d7c7df2c913622f719a60a..45a67dc8f99b5b79ea41a2638f5f43ea0d6b6d1c 100644 --- a/test/test_re/test_energies.py +++ b/test/test_re/test_energies.py @@ -1,7 +1,9 @@ #!/usr/bin/env python3 -import jax.numpy as jnp import pytest +pytest.importorskip("jax") + +import jax.numpy as jnp from functools import partial from jax import random from jax.tree_util import tree_map diff --git a/test/test_re/test_hmc_1d_distributions.py b/test/test_re/test_hmc_1d_distributions.py index 1eb520c7211f11b8815a67ec55869b5807082268..d07fb98c2cd0d77e44d4a10adc2ed4dcddc91cfa 100644 --- a/test/test_re/test_hmc_1d_distributions.py +++ b/test/test_re/test_hmc_1d_distributions.py @@ -1,9 +1,11 @@ import sys +import pytest +pytest.importorskip("jax") + from jax import numpy as jnp from jax.scipy import stats from numpy.testing import assert_allclose -import pytest import scipy from scipy.special import comb diff --git a/test/test_re/test_hmc_hashes.py b/test/test_re/test_hmc_hashes.py index 219951905c4d65ed738a74dfddd488ecd5169e03..8968acd7744ecf870caa63997ecaeeba53001eb3 100644 --- a/test/test_re/test_hmc_hashes.py +++ b/test/test_re/test_hmc_hashes.py @@ -1,5 +1,8 @@ import sys +import pytest +pytest.importorskip("jax") + from jax import numpy as jnp from jax.config import config as jax_config from numpy import ndarray diff --git a/test/test_re/test_hmc_leapfrog.py b/test/test_re/test_hmc_leapfrog.py index 8062d2fc97e7374910edb4ff6970c171c0d3b27d..4efa1871fc5af6b31362810e510b90e9a54b47ce 100644 --- a/test/test_re/test_hmc_leapfrog.py +++ b/test/test_re/test_hmc_leapfrog.py @@ -1,4 +1,6 @@ import pytest +pytest.importorskip("jax") + import sys from jax import grad from jax import numpy as jnp diff --git a/test/test_re/test_hmc_pytree.py b/test/test_re/test_hmc_pytree.py index 9a6f1e3c89e3fc1f86819834970cac3506075003..eec04c6d79707178a499e78458e6395d3e9a33d5 100644 --- a/test/test_re/test_hmc_pytree.py +++ b/test/test_re/test_hmc_pytree.py @@ -1,3 +1,6 @@ +import pytest +pytest.importorskip("jax") + from functools import partial from jax import numpy as jnp from jax.tree_util import tree_leaves diff --git a/test/test_re/test_lanczos.py b/test/test_re/test_lanczos.py index abc4edbac7510c90ccb281abba882bfd259ebae0..beea2bbb5a80893d9c4dab289a181f9eedb8d343 100644 --- a/test/test_re/test_lanczos.py +++ b/test/test_re/test_lanczos.py @@ -2,13 +2,15 @@ # SPDX-License-Identifier: GPL-2.0+ OR BSD-2-Clause +import pytest +pytest.importorskip("jax") + from functools import partial import sys from jax import random import jax.numpy as jnp import numpy as np from numpy.testing import assert_allclose -import pytest from scipy.spatial import distance_matrix import nifty8.re as jft diff --git a/test/test_re/test_ncg.py b/test/test_re/test_ncg.py index fbd9737d189c1fa66c4ca5a65582d1c1a839ce52..f096718ca3a07e14c922eeda49d5a0a09fa1cae3 100644 --- a/test/test_re/test_ncg.py +++ b/test/test_re/test_ncg.py @@ -2,10 +2,12 @@ import sys +import pytest +pytest.importorskip("jax") + from jax import random, value_and_grad import jax.numpy as jnp from numpy.testing import assert_allclose -import pytest import nifty8.re as jft diff --git a/test/test_re/test_refine.py b/test/test_re/test_refine.py index 8ad0075073a0a44c1109617d34116fc4f3a83f29..05f8f8be34ac7634fc07bd2797bf6e585cafad20 100644 --- a/test/test_re/test_refine.py +++ b/test/test_re/test_refine.py @@ -5,13 +5,14 @@ from functools import partial import sys -import jax +import pytest +pytest.importorskip("jax") + from jax import random import jax.numpy as jnp from jax.tree_util import Partial import numpy as np from numpy.testing import assert_allclose -import pytest from scipy.spatial import distance_matrix import nifty8.re as jft diff --git a/test/test_re/test_refine_util.py b/test/test_re/test_refine_util.py index 36df85c0d076b512c1c0d541bf3ae8a3e8b94bd6..59c7fcc99fd12c340d58bb12d910dd3e2830f1e6 100644 --- a/test/test_re/test_refine_util.py +++ b/test/test_re/test_refine_util.py @@ -4,9 +4,11 @@ from functools import partial +import pytest +pytest.importorskip("jax") + import jax import numpy as np -import pytest from nifty8.re import refine_chart, refine_util diff --git a/test/test_re/test_stats_distributions.py b/test/test_re/test_stats_distributions.py index 7e9e021f84d5a903f6c23c59a14f8ad7edb6de41..c911f669a4f92a2971ebdeec00a231131e680f20 100644 --- a/test/test_re/test_stats_distributions.py +++ b/test/test_re/test_stats_distributions.py @@ -1,6 +1,8 @@ +import pytest +pytest.importorskip("jax") + import numpy as np from numpy.testing import assert_allclose -import pytest import nifty8.re as jft