Commit 1d6e850e authored by Philipp Arras's avatar Philipp Arras
Browse files

Simplify

parent fe355625
......@@ -87,8 +87,7 @@ from .library.light_cone_operator import LightConeOperator
from .library.wiener_filter_curvature import WienerFilterCurvature
from .library.adjust_variances import (make_adjust_variances_hamiltonian,
do_adjust_variances)
from .library.nft import Gridder
from .library.nft import FinuFFT
from .library.nft import Gridder, FinuFFT
from .library.correlated_fields import CorrelatedFieldMaker
from .library.correlated_fields_simple import SimpleCorrelatedField
......
......@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from scipy.constants import speed_of_light
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
......@@ -30,7 +31,7 @@ class Gridder(LinearOperator):
self._target = makeDomain(target)
for ii in [0, 1]:
if target.shape[ii] % 2 != 0:
raise ValueError("even number of samples is required for gridding operation")
raise ValueError("even number of pixels is required for gridding operation")
if (len(self._target) != 1 or not isinstance(self._target[0], RGSpace)
or not len(self._target.shape) == 2):
raise ValueError("need target with exactly one 2D RGSpace")
......@@ -46,11 +47,9 @@ class Gridder(LinearOperator):
self._eps = float(eps)
self._nthreads = int(nthreads)
def apply(self, x, mode):
self._check_input(x, mode)
speedOfLight = 299792458.
freq = np.array([speedOfLight])
freq = np.array([speed_of_light])
x = x.val
nxdirty, nydirty = self._target[0].shape
dstx, dsty = self._target[0].distances
......@@ -68,7 +67,7 @@ class Gridder(LinearOperator):
class FinuFFT(LinearOperator):
"""
Operator computing nonuniformian FFTs using finufft package
Operator computing non-uniform FFTs using finufft package
Parameters
----------
......@@ -78,32 +77,29 @@ class FinuFFT(LinearOperator):
"""
def __init__(self, target, pos, eps=2e-10):
import finufft
self._capability = self.TIMES | self.ADJOINT_TIMES
self._target = makeDomain(target)
self._domain = DomainTuple.make(UnstructuredDomain((pos.shape[0])))
pos = (pos*self._target[0].distances) * 2*np.pi % (2*np.pi)
self._eps = float(eps/10) # @ TODO Philipp, how do you know?
self._eps = float(eps)
dst = np.array(self._target[0].distances)
pos = (2*np.pi*pos*dst) % (2*np.pi)
self._eps = float(eps/10)
if pos.ndim > 1:
self._pos = [pos[:, k] for k in range(pos.shape[1])]
self._method_strings = ('nufft' + str(pos.shape[1]) + 'd1',
'nufft' + str(pos.shape[1]) + 'd2')
s = 'nufft' + str(pos.shape[1]) + 'd'
else:
self._pos = [pos]
self._method_strings = ('nufft1d1' , 'nufft1d2')
s = 'nufft1d'
self._f = getattr(finufft, s+'1')
self._fadj = getattr(finufft, s+'2')
def apply(self, x, mode):
self._check_input(x,mode)
x = x.val
import finufft
self._check_input(x, mode)
if mode == self.TIMES:
x = np.copy(x)
res = getattr(finufft, self._method_strings[0])(*self._pos,
c=x,
n_modes=self._target[0].shape,
eps= self._eps).real
#TODO is this .real needed?
res = self._f(*self._pos, c=x.val_rw(),
n_modes=self._target[0].shape, eps=self._eps).real
# TODO is this .real needed?
if mode == self.ADJOINT_TIMES:
res = getattr(finufft, self._method_strings[1])(*self._pos,
f=x,
eps= self._eps)
res = self._fadj(*self._pos, f=x.val, eps=self._eps)
return makeField(self._tgt(mode), res)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment