diff --git a/demo/imaging_resolve_jax.py b/demo/imaging_resolve_jax.py
index 40581d4ed409d30312ffd01f1b6d1a0126feb241..63154ff3070c91b1d414907d3889ee07ec8697c5 100644
--- a/demo/imaging_resolve_jax.py
+++ b/demo/imaging_resolve_jax.py
@@ -11,8 +11,9 @@ from matplotlib.colors import LogNorm
 import configparser
 from jax import random
 
-response = 'ducc'
-# response = "finu"
+# choose between ducc0 and finufft backend
+response = 'ducc0'
+# response = "finufft"
 
 seed = 42
 key = random.PRNGKey(seed)
@@ -21,7 +22,7 @@ jax.config.update("jax_enable_x64", True)
 
 obs = rve.Observation.load("CYG-ALL-2052-2MHZ_RESOLVE_float64.npz")
 obs = obs.restrict_to_stokesi()
-# obs = obs.average_stokesi()
+obs = obs.average_stokesi()
 obs._weight = 0.1 * obs._weight # scale weights, as they are wrong for this specific dataset
 cfg = configparser.ConfigParser()
 cfg.read("cygnusa_2ghz.cfg")
@@ -32,23 +33,16 @@ sky, additional = jrve.sky_model(cfg["sky"])
 sky_sp = rve.sky_model._spatial_dom(cfg["sky"])
 sky_dom = rve.default_sky_domain(sdom=sky_sp)
 
-if response == "finu":
-    R_finufft = jrve.InterferometryResponseFinuFFT(
-        obs, sky_sp.distances[0], sky_sp.distances[1], 1e-9
-    )
-    signal_response = lambda x: R_finufft(sky(x)[0, 0, 0, :, :])
-elif response == 'ducc':
-    sky_domain_dict = dict(npix_x=sky_sp.shape[0],
-                        npix_y=sky_sp.shape[1],
-                        pixsize_x=sky_sp.distances[0],
-                        pixsize_y=sky_sp.distances[1],
-                        pol_labels=['I'],
-                        times=[0.],
-                        freqs=[0.])
-    R_new = jrve.InterferometryResponse(obs, sky_domain_dict, False, 1e-9)
-    signal_response = lambda x: R_new(sky(x))
-else:
-    raise ValueError()
+
+sky_domain_dict = dict(npix_x=sky_sp.shape[0],
+                    npix_y=sky_sp.shape[1],
+                    pixsize_x=sky_sp.distances[0],
+                    pixsize_y=sky_sp.distances[1],
+                    pol_labels=['I'],
+                    times=[0.],
+                    freqs=[0.])
+R_new = jrve.InterferometryResponse(obs, sky_domain_dict, False, 1e-9, backend=response)
+signal_response = lambda x: R_new(sky(x))
 
 
 nll = jft.Gaussian(obs.vis.val, obs.weight.val).amend(signal_response)
diff --git a/resolve/re/__init__.py b/resolve/re/__init__.py
index 7912af0fb74783716a206ebf5130cc3d80fbfd39..ce2ba46e2dd79aa995eb542f9f9c8aad80c18a68 100644
--- a/resolve/re/__init__.py
+++ b/resolve/re/__init__.py
@@ -1,3 +1,3 @@
 
 from .sky_model import sky_model_diffuse, sky_model_points, sky_model
-from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc, InterferometryResponseOld
\ No newline at end of file
+from .response import InterferometryResponse, InterferometryResponseFinuFFT, InterferometryResponseDucc
\ No newline at end of file
diff --git a/resolve/re/response.py b/resolve/re/response.py
index 43f3c22f4d8fec495d85773337aee969bdbc77b4..1c40d9cb794f8243168c6e4f083e587d44752bce 100644
--- a/resolve/re/response.py
+++ b/resolve/re/response.py
@@ -6,24 +6,25 @@ from functools import partial
 from ..util import dtype_float2complex
 from jax.tree_util import Partial
 
+
 def get_binbounds(coordinates):
     if len(coordinates) == 1:
-            return np.array([-np.inf, np.inf])
+        return np.array([-np.inf, np.inf])
     c = np.array(coordinates)
     bounds = np.empty(self.size + 1)
-    bounds[1:-1] = c[:-1] + 0.5*np.diff(c)
-    bounds[0] = c[0] - 0.5*(c[1] - c[0])
-    bounds[-1] = c[-1] + 0.5*(c[-1] - c[-2])
+    bounds[1:-1] = c[:-1] + 0.5 * np.diff(c)
+    bounds[0] = c[0] - 0.5 * (c[1] - c[0])
+    bounds[-1] = c[-1] + 0.5 * (c[-1] - c[-2])
     return bounds
 
 
 def convert_polarization(inp, inp_pol, out_pol):
-    if inp_pol == ('I',):
-        if out_pol == ('LL', 'RR') or out_pol == ('XX', 'YY'):
+    if inp_pol == ("I",):
+        if out_pol == ("LL", "RR") or out_pol == ("XX", "YY"):
             new_shp = list(inp.shape)
             new_shp[0] = 2
             return jnp.broadcast_to(inp, new_shp)
-        if len(out_pol) == 1 and out_pol[0] in ('I', 'RR', 'LL', 'XX', 'yy'):
+        if len(out_pol) == 1 and out_pol[0] in ("I", "RR", "LL", "XX", "yy"):
             return inp
     err = f"conversion of polarization {inp_pol} to {out_pol} not implemented. Please implement!"
     raise NotImplementedError(err)
@@ -36,6 +37,7 @@ def InterferometryResponse(
     epsilon,
     nthreads=1,
     verbosity=0,
+    backend="ducc0",
 ):
     """Returns a function computing the radio interferometric response
 
@@ -45,6 +47,8 @@ def InterferometryResponse(
         The observation for which the response should compute model visibilities
     sky_domain_dict: dict
         A dictionary providing information about the discretization of the sky.
+    do_wgridding : bool
+        Whether to perform wgridding.
     epsilon: float
         The numerical accuracy with which to evaluate the response.
     nthreads: int, optional
@@ -52,39 +56,60 @@ def InterferometryResponse(
     verbosity: int, optional
         If set to 1 prints information about the setup and performance of the
         response.
+    backend : string
+        If `ducc0` use ducc0 wgridder. If `finufft` use finufft to compute response.
     """
-    npix_x = sky_domain_dict['npix_x']
-    npix_y = sky_domain_dict['npix_y']
-    pixsize_x = sky_domain_dict['pixsize_x']
-    pixsize_y = sky_domain_dict['pixsize_y']
+    if do_wgridding and backend == "finufft":
+        raise RuntimeError("Cannot do wgridding with backend finufft.")
+
+    npix_x = sky_domain_dict["npix_x"]
+    npix_y = sky_domain_dict["npix_y"]
+    pixsize_x = sky_domain_dict["pixsize_x"]
+    pixsize_y = sky_domain_dict["pixsize_y"]
 
-    n_pol = len(sky_domain_dict['pol_labels'])
+    n_pol = len(sky_domain_dict["pol_labels"])
 
     # compute bins for time and freq
-    n_times = len(sky_domain_dict['times'])
-    bb_times = get_binbounds(sky_domain_dict['times'])
+    n_times = len(sky_domain_dict["times"])
+    bb_times = get_binbounds(sky_domain_dict["times"])
 
-    n_freqs = len(sky_domain_dict['freqs'])
-    bb_freqs = get_binbounds(sky_domain_dict['freqs'])
+    n_freqs = len(sky_domain_dict["freqs"])
+    bb_freqs = get_binbounds(sky_domain_dict["freqs"])
 
     # build responses for: time binds, freq bins
     sr = []
     row_indices, freq_indices = [], []
     for t in range(n_times):
         sr_tmp, t_tmp, f_tmp = [], [], []
-        if tuple(bb_times[t:t+2]) == (-np.inf, np.inf):
+        if tuple(bb_times[t : t + 2]) == (-np.inf, np.inf):
             oo = observation
             tind = slice(None)
         else:
-            oo, tind = observation.restrict_by_time(bb_times[t], bb_times[t+1], True)
+            oo, tind = observation.restrict_by_time(bb_times[t], bb_times[t + 1], True)
         for f in range(n_freqs):
-            ooo, find = oo.restrict_by_freq(bb_freqs[f], bb_freqs[f+1], True)
+            ooo, find = oo.restrict_by_freq(bb_freqs[f], bb_freqs[f + 1], True)
             if any(np.array(ooo.vis.shape) == 0):
                 rrr = None
             else:
-                rrr = InterferometryResponseDucc(ooo, npix_x, npix_y, pixsize_x,
-                                                pixsize_y, do_wgridding, epsilon,
-                                                nthreads, verbosity)
+                if backend == "ducc0":
+                    rrr = InterferometryResponseDucc(
+                        ooo,
+                        npix_x,
+                        npix_y,
+                        pixsize_x,
+                        pixsize_y,
+                        do_wgridding,
+                        epsilon,
+                        nthreads,
+                        verbosity,
+                    )
+                elif backend == "finufft":
+                    rrr = InterferometryResponseFinuFFT(
+                        ooo, pixsize_x, pixsize_y, epsilon
+                    )
+                else:
+                    err = f"backend must be `ducc0` or `finufft` not {backend}"
+                    raise ValueError(err)
 
             sr_tmp.append(rrr)
             t_tmp.append(tind)
@@ -93,18 +118,18 @@ def InterferometryResponse(
         row_indices.append(t_tmp)
         freq_indices.append(f_tmp)
 
-
-    target_shape = (n_pol, ) + tuple(observation.vis.shape[1:])
+    target_shape = (n_pol,) + tuple(observation.vis.shape[1:])
     foo = np.zeros(target_shape, np.int8)
     for pp in range(n_pol):
         for tt in range(n_times):
             for ff in range(n_freqs):
-                foo[pp, row_indices[tt][ff], freq_indices[tt][ff]] = 1.
+                foo[pp, row_indices[tt][ff], freq_indices[tt][ff]] = 1.0
     if np.any(foo == 0):
         raise RuntimeError("This should not happen. Please report.")
 
-    inp_pol = tuple(sky_domain_dict['pol_labels'])
+    inp_pol = tuple(sky_domain_dict["pol_labels"])
     out_pol = observation.vis.domain[0].labels
+
     def apply_R(sky):
         res = jnp.empty(target_shape, dtype_float2complex(sky.dtype))
         for pp in range(sky.shape[0]):
@@ -120,33 +145,6 @@ def InterferometryResponse(
 
     return apply_R
 
-def InterferometryResponseOld(
-    observation, domain, do_wgridding, epsilon, verbosity=0, nthreads=1
-):
-    import jax_linop
-    from ..response import InterferometryResponse
-
-    R_old = InterferometryResponse(
-        observation, domain, do_wgridding, epsilon, verbosity, nthreads
-    )
-
-    def R(inp, out, state):
-        inp = ift.makeField(R_old.domain, inp)
-        out[()] = R_old(inp).val
-
-    def Re_T(inp, out, state):
-        inp = ift.makeField(R_old.target, inp.conj())
-        out[()] = R_old.adjoint(inp).val.conj()
-
-    def R_abstract(shape, dtype, state):
-        return R_old.target.shape, np.dtype(np.complex128)
-
-    def R_abstract_T(shape, dtype, state):
-        return R_old.domain.shape, np.dtype(np.float64)
-
-    R_jax = jax_linop.get_linear_call(R, Re_T, R_abstract, R_abstract_T)
-    return lambda x: R_jax(x)[0]
-
 
 def InterferometryResponseDucc(
     observation,
@@ -195,7 +193,7 @@ def InterferometryResponseFinuFFT(observation, pixsizex, pixsizey, epsilon):
 
     def apply_finufft(inp, u, v, eps):
         res = vol * nufft2(inp.astype(np.complex128), u, v, eps=eps)
-        return jnp.expand_dims(res.reshape(-1, len(freq)), 0)
+        return res.reshape(-1, len(freq))
 
-    R = partial(apply_finufft, u=u_finu, v=v_finu, eps=epsilon)
+    R = Partial(apply_finufft, u=u_finu, v=v_finu, eps=epsilon)
     return R