Commit cc4b112f authored by Vincent Eberle's avatar Vincent Eberle
Browse files

put FinuFFT into FFTInterpolator

parent 7c9370d6
Pipeline #93493 failed with stages
in 32 seconds
......@@ -81,6 +81,8 @@ class FinuFFT(LinearOperator):
self._capability = self.TIMES | self.ADJOINT_TIMES
self._target = makeDomain(target)
self._domain = DomainTuple.make(UnstructuredDomain((pos.shape[0])))
if pos.ndim != 2:
raise ValueError("sampling_points must be a 2D array")
pos = (pos*self._target[0].distances) * 2*np.pi % (2*np.pi)
self._eps = float(eps/10) # @ TODO Philipp, how do you know?
if pos.ndim > 1:
......
......@@ -18,7 +18,7 @@
import numpy as np
from ..library.gridder import Gridder
from ..library.gridder import FinuFFT
from ..sugar import makeDomain, makeField
from .harmonic_operators import HartleyOperator
from ..domains.rg_space import RGSpace
......@@ -26,7 +26,7 @@ from .linear_operator import LinearOperator
class FFTInterpolator(LinearOperator):
"""FFT Interpolation using Gridder and HartleyOperator
"""FFT Interpolation using FinuFFT and HartleyOperator
Parameters
---------
......@@ -38,41 +38,35 @@ class FFTInterpolator(LinearOperator):
Notes
----
#FIXME Documentation from Philipp ? PBCs? / Torus?
#NOTE implement Switch for PhiNuFFT?
#FIXME Documentation from Philipp
"""
def __init__(self, domain, pos, eps=2e-10, nthreads=1):
def __init__(self, domain, pos, eps=2e-10):
self._domain = makeDomain(domain)
if not isinstance(pos, np.ndarray):
raise TypeError("sampling_points need to be a numpy.ndarray")
if pos.ndim != 2:
raise ValueError("sampling_points must be a 2D array")
if pos.shape[0] != 2:
raise ValueError("first dimension of sampling_points must have length 2")
for ii in [0, 1]:
if domain.shape[ii] % 2 != 0:
raise ValueError("even number of samples is required for gridding operation")
dist = [list(dom.distances) for dom in self.domain]
dist = np.array(dist).reshape(-1, 1)
pos = pos / dist
gridderdom = RGSpace(self.domain.shape)
self._gridder = Gridder(gridderdom, pos.T, eps, nthreads)
finudom = RGSpace(self.domain.shape)
self._finu = FinuFFT(finudom, pos.T, eps)
self._ht = HartleyOperator(self._domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._target = self._gridder.domain
self._target = self._finu.domain
def apply(self, x, mode):
self._check_input(x, mode)
ht = self._ht
gridder = self._gridder
finu = self._finu
nx, ny = ht.target.shape
if mode == self.TIMES:
x = ht(x)
x = makeField(gridder.target, np.fft.fftshift(x.val))
x = gridder.adjoint(x)
x = makeField(finu.target, np.fft.fftshift(x.val))
x = finu.adjoint(x)
x = x.real + x.imag
else:
x = gridder(x + 1j*x)
x = makeField(ht.target, np.fft.fftshift(x.val))
x = finu(x + 1j*x)
x = makeField(ht.target, np.fft.ifftshift(x.val))
x = ht.adjoint(x)
return x/self.domain.total_volume()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment