Commit e8f33125 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweak FFT operator

parent 922878ed
Pipeline #23731 failed with stage
in 4 minutes and 4 seconds
......@@ -73,6 +73,7 @@ if __name__ == "__main__":
j = R_harmonic.adjoint_times(N.inverse_times(data))
print "xx",j.val[0]*nu.K*(nu.m**dimensionality)
exit()
ctrl = ift.GradientNormController(
verbose=True, tol_abs_gradnorm=1e-40/(nu.K*(nu.m**dimensionality)))
inverter = ift.ConjugateGradient(controller=ctrl)
......
......@@ -21,18 +21,26 @@ from .. import DomainTuple
from ..spaces import RGSpace
from ..utilities import infer_space
from .linear_operator import LinearOperator
from .fft_operator_support import RGRGTransformation, SphericalTransformation
from .. import dobj
from .. import utilities
from ..field import Field
from ..spaces.gl_space import GLSpace
class FFTOperator(LinearOperator):
"""Transforms between a pair of harmonic and position domains.
"""Transforms between a pair of position and harmonic domains.
Built-in domain pairs are
- harmonic RGSpace / nonharmonic RGSpace (with matching distances)
- LMSpace / HPSpace
- LMSpace / GLSpace
The times() operation always transforms from the harmonic to the
position domain.
- a harmonic and a non-harmonic RGSpace (with matching distances)
- a HPSpace and a LMSpace
- a GLSpace and a LMSpace
Within a domain pair, both orderings are possible.
For RGSpaces, the operator provides the full set of operations.
For the sphere-related domains, it only supports the transform from
harmonic to position space and its adjoint; if the operator domain is
harmonic, this will be times() and adjoint_times(), otherwise
inverse_times() and adjoint_inverse_times()
Parameters
----------
......@@ -58,33 +66,158 @@ class FFTOperator(LinearOperator):
# Initialize domain and target
self._domain = DomainTuple.make(domain)
self._space = infer_space(self._domain, space)
if not self._domain[self._space].harmonic:
raise TypeError("H2POperator must work on a harmonic domain")
adom = self.domain[self._space]
adom = self._domain[self._space]
if target is None:
target = adom.get_default_codomain()
self._target = [dom for dom in self.domain]
self._target = [dom for dom in self._domain]
self._target[self._space] = target
self._target = DomainTuple.make(self._target)
adom.check_codomain(target)
target.check_codomain(adom)
hdom, pdom = (self._domain, self._target)
if isinstance(pdom[self._space], RGSpace):
self._trafo = RGRGTransformation(hdom, pdom, self._space)
if isinstance(adom, RGSpace):
self._applyfunc = self._apply_cartesian
self._capability = self._all_ops
import pyfftw
pyfftw.interfaces.cache.enable()
else:
self._trafo = SphericalTransformation(hdom, pdom, self._space)
from pyHealpix import sharpjob_d
self._applyfunc = self._apply_spherical
hspc = adom if adom.harmonic else target
pspc = target if adom.harmonic else adom
self.lmax=hspc.lmax
self.mmax=hspc.mmax
self.sjob = sharpjob_d()
self.sjob.set_triangular_alm_info(self.lmax, self.mmax)
if isinstance(pspc, GLSpace):
self.sjob.set_Gauss_geometry(pspc.nlat, pspc.nlon)
else:
self.sjob.set_Healpix_geometry(pspc.nside)
if adom.harmonic:
self._capability = self.TIMES | self.ADJOINT_TIMES
else:
self._capability = (self.INVERSE_TIMES |
self.INVERSE_ADJOINT_TIMES)
def apply(self, x, mode):
self._check_input(x, mode)
if np.issubdtype(x.dtype, np.complexfloating):
res = (self._trafo.apply(x.real, mode) +
1j * self._trafo.apply(x.imag, mode))
return (self._applyfunc(x.real, mode) +
1j*self._applyfunc(x.imag, mode))
else:
res = self._trafo.apply(x, mode)
return res
return self._applyfunc(x, mode)
def _apply_cartesian(self, x, mode):
"""
RG -> RG transform method.
Parameters
----------
x : Field
The field to be transformed
"""
from pyfftw.interfaces.numpy_fft import fftn
axes = x.domain.axes[self._space]
tdom = self._target if x.domain==self._domain else self._domain
oldax = dobj.distaxis(x.val)
if oldax not in axes: # straightforward, no redistribution needed
ldat = dobj.local_data(x.val)
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
elif len(axes) < len(x.shape) or len(axes) == 1:
# we can use one Hartley pass in between the redistributions
tmp = dobj.redistribute(x.val, nodist=axes)
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
else: # two separate, full FFTs needed
# ideal strategy for the moment would be:
# - do real-to-complex FFT on all local axes
# - fill up array
# - redistribute array
# - do complex-to-complex FFT on remaining axis
# - add re+im
# - redistribute back
rem_axes = tuple(i for i in axes if i != oldax)
tmp = x.val
ldat = dobj.local_data(tmp)
ldat = utilities.my_fftn_r2c(ldat, axes=rem_axes)
if oldax != 0:
raise ValueError("bad distribution")
ldat2 = ldat.reshape((ldat.shape[0],
np.prod(ldat.shape[1:])))
shp2d = (x.val.shape[0], np.prod(x.val.shape[1:]))
tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp)
ldat2 = fftn(ldat2, axes=(1,))
ldat2 = ldat2.real+ldat2.imag
tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp).reshape(ldat.shape)
tmp = dobj.from_local_data(x.val.shape, ldat2, distaxis=0)
Tval = Field(tdom, tmp)
if x.domain[self._space].harmonic:
if (mode == LinearOperator.TIMES or
mode == LinearOperator.ADJOINT_TIMES):
fct = self._domain[self._space].scalar_dvol()
else:
fct = 1./(self._domain[self._space].scalar_dvol()*self._domain[self._space].dim)
else:
if (mode == LinearOperator.TIMES or
mode == LinearOperator.ADJOINT_TIMES):
fct = 1./(self._target[self._space].scalar_dvol()*self._target[self._space].dim)
else:
fct = self._target[self._space].scalar_dvol()
if fct != 1:
Tval *= fct
return Tval
def _slice_p2h(self, inp):
rr = self.sjob.alm2map_adjoint(inp)
assert len(rr) == ((self.mmax+1)*(self.mmax+2))//2 + \
(self.mmax+1)*(self.lmax-self.mmax)
res = np.empty(2*len(rr)-self.lmax-1, dtype=rr[0].real.dtype)
res[0:self.lmax+1] = rr[0:self.lmax+1].real
res[self.lmax+1::2] = np.sqrt(2)*rr[self.lmax+1:].real
res[self.lmax+2::2] = np.sqrt(2)*rr[self.lmax+1:].imag
return res/np.sqrt(np.pi*4)
def _slice_h2p(self, inp):
res = np.empty((len(inp)+self.lmax+1)//2, dtype=(inp[0]*1j).dtype)
assert len(res) == ((self.mmax+1)*(self.mmax+2))//2 + \
(self.mmax+1)*(self.lmax-self.mmax)
res[0:self.lmax+1] = inp[0:self.lmax+1]
res[self.lmax+1:] = np.sqrt(0.5)*(inp[self.lmax+1::2] +
1j*inp[self.lmax+2::2])
res = self.sjob.alm2map(res)
return res/np.sqrt(np.pi*4)
def _apply_spherical(self, x, mode):
axes = x.domain.axes[self._space]
axis = axes[0]
tval = x.val
if dobj.distaxis(tval) == axis:
tval = dobj.redistribute(tval, nodist=(axis,))
distaxis = dobj.distaxis(tval)
p2h = not x.domain[self._space].harmonic
tdom = self._target if x.domain==self._domain else self._domain
func = self._slice_p2h if p2h else self._slice_h2p
idat = dobj.local_data(tval)
odat = np.empty(dobj.local_shape(tdom.shape, distaxis=distaxis),
dtype=x.dtype)
for slice in utilities.get_slice_list(idat.shape, axes):
odat[slice] = func(idat[slice])
odat = dobj.from_local_data(tdom.shape, odat, distaxis)
if distaxis != dobj.distaxis(x.val):
odat = dobj.redistribute(odat, dist=dobj.distaxis(x.val))
return Field(tdom, odat)
@property
def domain(self):
......@@ -96,7 +229,4 @@ class FFTOperator(LinearOperator):
@property
def capability(self):
res = self.TIMES | self.ADJOINT_TIMES
if self._trafo.unitary:
res |= self.INVERSE_TIMES | self.ADJOINT_INVERSE_TIMES
return res
return self._capability
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
import numpy as np
from .. import utilities
from .. import dobj
from ..field import Field
from ..spaces.gl_space import GLSpace
from .linear_operator import LinearOperator
class Transformation(object):
def __init__(self, hdom, pdom, space):
self.hdom = hdom
self.pdom = pdom
self.space = space
class RGRGTransformation(Transformation):
def __init__(self, hdom, pdom, space):
import pyfftw
super(RGRGTransformation, self).__init__(hdom, pdom, space)
pyfftw.interfaces.cache.enable()
self.fct_noninverse = hdom[space].scalar_dvol()
self.fct_inverse = 1./(hdom[space].scalar_dvol()*hdom[space].dim)
@property
def unitary(self):
return True
def apply(self, x, mode):
"""
RG -> RG transform method.
Parameters
----------
x : Field
The field to be transformed
"""
from pyfftw.interfaces.numpy_fft import fftn
axes = x.domain.axes[self.space]
p2h = x.domain == self.pdom
tdom = self.hdom if p2h else self.pdom
oldax = dobj.distaxis(x.val)
if oldax not in axes: # straightforward, no redistribution needed
ldat = dobj.local_data(x.val)
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(x.val.shape, ldat, distaxis=oldax)
elif len(axes) < len(x.shape) or len(axes) == 1:
# we can use one Hartley pass in between the redistributions
tmp = dobj.redistribute(x.val, nodist=axes)
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = utilities.hartley(ldat, axes=axes)
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
else: # two separate, full FFTs needed
# ideal strategy for the moment would be:
# - do real-to-complex FFT on all local axes
# - fill up array
# - redistribute array
# - do complex-to-complex FFT on remaining axis
# - add re+im
# - redistribute back
if True:
rem_axes = tuple(i for i in axes if i != oldax)
tmp = x.val
ldat = dobj.local_data(tmp)
ldat = utilities.my_fftn_r2c(ldat, axes=rem_axes)
# new, experimental code
if True:
if oldax != 0:
raise ValueError("bad distribution")
ldat2 = ldat.reshape((ldat.shape[0],
np.prod(ldat.shape[1:])))
shp2d = (x.val.shape[0], np.prod(x.val.shape[1:]))
tmp = dobj.from_local_data(shp2d, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp)
ldat2 = fftn(ldat2, axes=(1,))
ldat2 = ldat2.real+ldat2.imag
tmp = dobj.from_local_data(tmp.shape, ldat2, distaxis=0)
tmp = dobj.transpose(tmp)
ldat2 = dobj.local_data(tmp).reshape(ldat.shape)
tmp = dobj.from_local_data(x.val.shape, ldat2, distaxis=0)
else:
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=oldax)
tmp = dobj.redistribute(tmp, nodist=(oldax,))
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=(oldax,))
ldat = ldat.real+ldat.imag
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
else:
tmp = dobj.redistribute(x.val, nodist=(oldax,))
newax = dobj.distaxis(tmp)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=(oldax,))
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=newax)
tmp = dobj.redistribute(tmp, dist=oldax)
rem_axes = tuple(i for i in axes if i != oldax)
ldat = dobj.local_data(tmp)
ldat = fftn(ldat, axes=rem_axes)
ldat = ldat.real+ldat.imag
tmp = dobj.from_local_data(tmp.shape, ldat, distaxis=oldax)
Tval = Field(tdom, tmp)
if (mode == LinearOperator.TIMES or
mode == LinearOperator.ADJOINT_TIMES):
fct = self.fct_noninverse
else:
fct = self.fct_inverse
if fct != 1:
Tval *= fct
return Tval
class SphericalTransformation(Transformation):
def __init__(self, hdom, pdom, space):
super(SphericalTransformation, self).__init__(hdom, pdom, space)
from pyHealpix import sharpjob_d
self.lmax = self.hdom[self.space].lmax
self.mmax = self.hdom[self.space].mmax
self.sjob = sharpjob_d()
self.sjob.set_triangular_alm_info(self.lmax, self.mmax)
if isinstance(self.pdom[self.space], GLSpace):
self.sjob.set_Gauss_geometry(self.pdom[self.space].nlat,
self.pdom[self.space].nlon)
else:
self.sjob.set_Healpix_geometry(self.pdom[self.space].nside)
@property
def unitary(self):
return False
def _slice_p2h(self, inp):
rr = self.sjob.alm2map_adjoint(inp)
assert len(rr) == ((self.mmax+1)*(self.mmax+2))//2 + \
(self.mmax+1)*(self.lmax-self.mmax)
res = np.empty(2*len(rr)-self.lmax-1, dtype=rr[0].real.dtype)
res[0:self.lmax+1] = rr[0:self.lmax+1].real
res[self.lmax+1::2] = np.sqrt(2)*rr[self.lmax+1:].real
res[self.lmax+2::2] = np.sqrt(2)*rr[self.lmax+1:].imag
return res/np.sqrt(np.pi*4)
def _slice_h2p(self, inp):
res = np.empty((len(inp)+self.lmax+1)//2, dtype=(inp[0]*1j).dtype)
assert len(res) == ((self.mmax+1)*(self.mmax+2))//2 + \
(self.mmax+1)*(self.lmax-self.mmax)
res[0:self.lmax+1] = inp[0:self.lmax+1]
res[self.lmax+1:] = np.sqrt(0.5)*(inp[self.lmax+1::2] +
1j*inp[self.lmax+2::2])
res = self.sjob.alm2map(res)
return res/np.sqrt(np.pi*4)
def apply(self, x, mode):
axes = x.domain.axes[self.space]
axis = axes[0]
tval = x.val
if dobj.distaxis(tval) == axis:
tval = dobj.redistribute(tval, nodist=(axis,))
distaxis = dobj.distaxis(tval)
p2h = x.domain == self.pdom
tdom = self.hdom if p2h else self.pdom
func = self._slice_p2h if p2h else self._slice_h2p
idat = dobj.local_data(tval)
odat = np.empty(dobj.local_shape(tdom.shape, distaxis=distaxis),
dtype=x.dtype)
for slice in utilities.get_slice_list(idat.shape, axes):
odat[slice] = func(idat[slice])
odat = dobj.from_local_data(tdom.shape, odat, distaxis)
if distaxis != dobj.distaxis(x.val):
odat = dobj.redistribute(odat, dist=dobj.distaxis(x.val))
return Field(tdom, odat)
......@@ -16,13 +16,12 @@ def FFTSmoothingOperator(domain, sigma, space=None):
space = infer_space(domain, space)
if domain[space].harmonic:
raise TypeError("domain must not be harmonic")
fftdom = list(domain)
codomain = domain[space].get_default_codomain()
fftdom[space] = codomain
fftdom = DomainTuple.make(fftdom)
FFT = FFTOperator(fftdom, domain[space], space=space)
FFT = FFTOperator(domain, space=space)
codomain = FFT.domain[space].get_default_codomain()
kernel = codomain.get_k_length_array()
smoother = codomain.get_fft_smoothing_kernel_function(sigma)
kernel = smoother(kernel)
diag = DiagonalOperator(kernel, fftdom, space)
return FFT*diag*FFT.inverse
ddom = list(domain)
ddom[space] = codomain
diag = DiagonalOperator(kernel, ddom, space)
return FFT.inverse*diag*FFT
......@@ -39,7 +39,5 @@ class Adjointness_Tests(unittest.TestCase):
@expand(product(_harmonic_spaces+_position_spaces,
[np.float64, np.complex128]))
def testFFT(self, sp, dtype):
if not sp.harmonic:
sp = sp.get_default_codomain()
op = ift.FFTOperator(sp)
_check_adjointness(op, dtype)
......@@ -37,8 +37,8 @@ class FFTOperatorTests(unittest.TestCase):
[np.float64, np.float32, np.complex64, np.complex128]))
def test_fft1D(self, dim1, d, itp):
tol = _get_rtol(itp)
b = ift.RGSpace(dim1, distances=d)
a = ift.RGSpace(dim1, distances=1./(dim1*d), harmonic=True)
a = ift.RGSpace(dim1, distances=d)
b = ift.RGSpace(dim1, distances=1./(dim1*d), harmonic=True)
fft = ift.FFTOperator(domain=a, target=b)
np.random.seed(16)
......@@ -53,8 +53,8 @@ class FFTOperatorTests(unittest.TestCase):
[np.float64, np.float32, np.complex64, np.complex128]))
def test_fft2D(self, dim1, dim2, d1, d2, itp):
tol = _get_rtol(itp)
b = ift.RGSpace([dim1, dim2], distances=[d1, d2])
a = ift.RGSpace([dim1, dim2],
a = ift.RGSpace([dim1, dim2], distances=[d1, d2])
b = ift.RGSpace([dim1, dim2],
distances=[1./(dim1*d1), 1./(dim2*d2)], harmonic=True)
fft = ift.FFTOperator(domain=a, target=b)
......@@ -78,8 +78,8 @@ class FFTOperatorTests(unittest.TestCase):
assert_allclose(ift.dobj.to_global_data(inp.val),
ift.dobj.to_global_data(out.val), rtol=tol, atol=tol)
@expand(product([0, 3, 6, 11, 30],
[np.float64, np.float32, np.complex64, np.complex128]))
#@expand(product([0, 3, 6, 11, 30],
# [np.float64, np.float32, np.complex64, np.complex128]))
#def test_sht(self, lm, tp):
# tol = _get_rtol(tp)
# a = ift.LMSpace(lmax=lm)
......@@ -130,3 +130,15 @@ class FFTOperatorTests(unittest.TestCase):
v1 = np.sqrt(out.vdot(out))
v2 = np.sqrt(inp.vdot(fft.adjoint_times(out)))
assert_allclose(v1, v2, rtol=tol, atol=tol)
@expand(product([ift.RGSpace(128, distances=3.76, harmonic=True),
ift.LMSpace(lmax=30, mmax=25)],
[np.float64, np.float32, np.complex64, np.complex128]))
def test_normalisation(self, space, tp):
tol = 10 * _get_rtol(tp)
fft = ift.FFTOperator(space)
inp = ift.Field.from_random(domain=space, random_type='normal',
std=1, mean=2, dtype=tp)
out = fft.times(inp)
assert_allclose(ift.dobj.to_global_data(inp.val)[0], out.integrate(),
rtol=tol, atol=tol)
......@@ -9,18 +9,18 @@ class ResponseOperator_Tests(unittest.TestCase):
spaces = [ift.RGSpace(128), ift.GLSpace(nlat=37)]
@expand(product(spaces, [0., 5., 1.], [0., 1., .33]))
def test_property(self, space, sigma, exposure):
def test_property(self, space, sigma, sensitivity):
op = ift.ResponseOperator(space, sigma=[sigma],
exposure=[exposure])
sensitivity=[sensitivity])
if op.domain[0] != space:
raise TypeError
@expand(product(spaces, [0., 5., 1.], [0., 1., .33]))
def test_times_adjoint_times(self, space, sigma, exposure):
def test_times_adjoint_times(self, space, sigma, sensitivity):
if not isinstance(space, ift.RGSpace): # no smoothing supported
sigma = 0.
op = ift.ResponseOperator(space, sigma=[sigma],
exposure=[exposure])
sensitivity=[sensitivity])
rand1 = ift.Field.from_random('normal', domain=space)
rand2 = ift.Field.from_random('normal', domain=op.target[0])
tt1 = rand2.vdot(op.times(rand1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment