Skip to content
Snippets Groups Projects
Commit ff4f61af authored by Philipp Arras's avatar Philipp Arras
Browse files

Remove everything jax-related

parent 84fb9ba7
No related branches found
No related tags found
1 merge request!36Pypi
......@@ -8,7 +8,7 @@ RUN apt-get update -qq && apt-get install -qq git
RUN apt-get update -qq && apt-get install -qq python3-pip
RUN pip3 install pybind11
# Optional dependencies
RUN pip3 install astropy jax jaxlib
RUN pip3 install astropy
RUN apt-get install -qq python3-mpi4py
# Testing dependencies
RUN apt-get install -qq python3-pytest-cov
......
......@@ -35,7 +35,6 @@ Automatically installed by installation script:
Optional dependencies:
- astropy
- jax, jaxlib
## Installation
......@@ -48,7 +47,7 @@ First install the necessary dependencies, for example via:
Optionally install afterwards:
pip3 install astropy jax jaxlib
pip3 install astropy
Finally, clone the resolve repository and install resolve on your system:
......
......@@ -34,7 +34,6 @@ dom = rve.default_sky_domain(pdom=pdom, sdom=sdom)
dom = {kk: dom[1:] for kk in pdom.labels}
tgt = rve.default_sky_domain(pdom=pdom, sdom=sdom)
opold = rve.polarization_matrix_exponential(tgt)
opold_jax = rve.polarization_matrix_exponential(tgt, jax=True)
for nthreads in [1, 4, 8]:
op = rve.polarization_matrix_exponential_mf2f(dom, nthreads)
......@@ -43,6 +42,3 @@ for nthreads in [1, 4, 8]:
print()
print("Old implementation")
ift.exec_time(opold)
print()
print("Old implementation (jax)")
ift.exec_time(opold_jax)
......@@ -2,5 +2,5 @@
requires-python = ">=3.7"
[build-system]
requires = ["setuptools >= 40.6.0", "pybind11 >= 2.6.0", "ducc0", "matplotlib", "h5py", "python-casacore", "scipy", "nifty8"]
requires = ["setuptools >= 40.6.0", "pybind11 >= 2.6.0", "ducc0", "matplotlib", "h5py", "nifty8"]
build-backend = "setuptools.build_meta"
......@@ -27,7 +27,10 @@ from .polarization import Polarization
def ms_table(path):
try:
from casacore.tables import table
except ImportError:
raise ImportError("You need to install python-casacore for working with measurement sets")
return table(path, readonly=True, ack=False)
......
......@@ -55,7 +55,7 @@ def polarization_matrix_exponential_mf2f(domain, nthreads=1):
return Pybind11Operator(domain, target, f(nthreads))
def polarization_matrix_exponential(domain, jax=False):
def polarization_matrix_exponential(domain):
"""
Deprecated.
......@@ -74,9 +74,6 @@ def polarization_matrix_exponential(domain, jax=False):
if pdom.labels_eq("I"):
return ift.ScalingOperator(domain, 1.).exp()
if jax:
return _jax_pol(domain)
mfs = MultiFieldStacker(domain, 0, domain[0].labels)
op = PolarizationMatrixExponential(mfs.domain)
return mfs @ op @ mfs.inverse
......@@ -116,36 +113,3 @@ class PolarizationMatrixExponential(ift.Operator):
if with_v:
return I.unite(U.unite(Q.unite(V)))
return I.unite(U.unite(Q))
def _jax_pol(domain):
from jax.numpy import cosh, empty, exp, float64, sinh, sqrt, zeros
domain = ift.makeDomain(domain)
pdom = domain[0]
assert isinstance(pdom, PolarizationSpace)
with_v = "V" in pdom.labels
I = pdom.label2index("I")
Q = pdom.label2index("Q")
U = pdom.label2index("U")
if with_v:
V = pdom.label2index("V")
def func(x):
sq = x[Q] ** 2 + x[U] ** 2
if with_v:
sq += x[V] ** 2
log_p = sqrt(sq)
tmpi = exp(x[I])
res = empty(domain.shape, float64)
res = res.at[I].set(tmpi * cosh(log_p))
tmp = tmpi * sinh(log_p) / log_p
res = res.at[U].set(tmp * x[U])
res = res.at[Q].set(tmp * x[Q])
if with_v:
res = res.at[V].set(tmp * x[V])
return res
return ift.JaxOperator(domain, domain, func)
......@@ -33,17 +33,6 @@ from .simple_operators import MultiFieldStacker
from .util import assert_sky_domain
def _has_jax():
try:
import jax
print("Will use jax in double precision on the CPU")
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
return True
except ImportError:
return False
def sky_model_diffuse(cfg, observations=[], nthreads=1):
sdom = _spatial_dom(cfg)
pdom = PolarizationSpace(cfg["polarization"].split(","))
......
......@@ -142,9 +142,8 @@ setup(
packages=find_packages(include=["resolve", "resolve.*", "resolve_support", "resolve_support.*"]),
zip_safe=True,
dependency_links=[],
install_requires=["ducc0", "matplotlib", "h5py", "python-casacore", "scipy", "nifty8"],
extras_require={"full": ("jax", "astropy", "pytest", "pytest-cov", "mpi4py"),
"mpi": ("mpi4py",)},
install_requires=["ducc0", "matplotlib", "h5py", "scipy", "nifty8"],
extras_require={"full": ("astropy", "pytest", "pytest-cov", "mpi4py", "python-casacore")},
ext_modules=extensions,
entry_points={"console_scripts":
[
......
import resolve as rve
import numpy as np
import jax.numpy as jnp
import resolve_support
import nifty8 as ift
......
......@@ -38,14 +38,11 @@ restdom = list2fixture([[ift.UnstructuredDomain(7)],
def test_different_implementations(pdom, restdom):
dom = tuple((pdom,)) + tuple(restdom)
op0 = rve.polarization_matrix_exponential(dom, False)
op1 = rve.polarization_matrix_exponential(dom, True)
op0 = rve.polarization_matrix_exponential(dom)
loc = ift.from_random(op0.domain)
ift.extra.check_operator(op0, loc, ntries=3)
ift.extra.check_operator(op1, loc, ntries=3)
ift.extra.assert_allclose(op0(loc), op1(loc))
if pdom.labels_eq(["I", "Q", "U", "V"]):
op2 = rve.polarization_matrix_exponential_mf2f({kk: restdom for kk in pdom.labels})
......@@ -57,19 +54,9 @@ def test_different_implementations(pdom, restdom):
@pmp("pol", ("I", ["I", "Q", "U"], ["I", "Q", "U", "V"]))
def test_polarization(pol):
dom = rve.PolarizationSpace(pol), rve.IRGSpace([0]), rve.IRGSpace([0]), ift.RGSpace([10, 20])
op = rve.polarization_matrix_exponential(dom, False)
op = rve.polarization_matrix_exponential(dom)
pos = ift.from_random(op.domain)
ift.extra.check_operator(op, pos, ntries=5)
try:
op_jax = rve.polarization_matrix_exponential(dom, True)
assert op.domain is op_jax.domain
assert op.target is op_jax.target
ift.extra.assert_allclose(op(pos), op_jax(pos))
ift.extra.check_operator(op_jax, pos, ntries=5)
except ImportError:
pass
def test_polarization_matrix_exponential():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment