Commit fe3bc980 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweak FFTs

parent 876fde66
Pipeline #17907 passed with stage
in 3 minutes and 22 seconds
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import object, range
class SerialFFT(object):
"""
The pyfftw pendant of a fft object.
"""
def __init__(self, domain, codomain):
import pyfftw
self.domain = domain
self.codomain = codomain
pyfftw.interfaces.cache.enable()
def transform(self, val, axes):
"""
The scalar FFT transform function.
Parameters
----------
val : numpy.ndarray
The value-array of the field which is supposed to
be transformed.
axes: tuple, None
The axes which should be transformed.
Returns
-------
result : numpy.ndarray
Fourier-transformed pendant of the input field.
"""
from pyfftw.interfaces.numpy_fft import fftn, ifftn
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("Provided axes does not match array shape")
if self.codomain.harmonic:
return fftn(val, axes=axes)
else:
return ifftn(val, axes=axes)
......@@ -19,7 +19,6 @@
from __future__ import division
import numpy as np
from .transformation import Transformation
from .rg_transforms import SerialFFT
class RGRGTransformation(Transformation):
......@@ -27,8 +26,10 @@ class RGRGTransformation(Transformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None):
import pyfftw
super(RGRGTransformation, self).__init__(domain, codomain)
self._transform = SerialFFT(self.domain, self.codomain)
pyfftw.interfaces.cache.enable()
self._fwd = self.codomain.harmonic
# ---Mandatory properties and methods---
......@@ -36,6 +37,16 @@ class RGRGTransformation(Transformation):
def unitary(self):
return True
def _transform_helper(self, val, axes):
from pyfftw.interfaces.numpy_fft import fftn, ifftn
# Check if the axes provided are valid given the shape
if axes is not None and \
not all(axis in range(len(val.shape)) for axis in axes):
raise ValueError("Provided axes does not match array shape")
return fftn(val, axes=axes) if self._fwd else ifftn(val, axes=axes)
def transform(self, val, axes=None):
"""
RG -> RG transform method.
......@@ -50,18 +61,18 @@ class RGRGTransformation(Transformation):
"""
fct=1.
if self._transform.codomain.harmonic:
if self.codomain.harmonic:
# correct for forward fft.
# naively one would set power to 0.5 here in order to
# apply effectively a factor of 1/sqrt(N) to the field.
# BUT: the pixel volumes of the domain and codomain are different.
# Hence, in order to produce the same scalar product, power===1.
fct *= self._transform.domain.weight()
fct *= self.domain.weight()
# Perform the transformation
if issubclass(val.dtype.type, np.complexfloating):
Tval_real = self._transform.transform(val.real, axes)
Tval_imag = self._transform.transform(val.imag, axes)
Tval_real = self._transform_helper(val.real, axes)
Tval_imag = self._transform_helper(val.imag, axes)
if self.codomain.harmonic:
Tval_real.real += Tval_real.imag
Tval_real.imag = Tval_imag.real + Tval_imag.imag
......@@ -71,17 +82,17 @@ class RGRGTransformation(Transformation):
Tval = Tval_real
else:
Tval = self._transform.transform(val, axes)
Tval = self._transform_helper(val, axes)
if self.codomain.harmonic:
Tval.real += Tval.imag
else:
Tval.real -= Tval.imag
Tval = Tval.real
if not self._transform.codomain.harmonic:
if not self.codomain.harmonic:
# correct for inverse fft.
# See discussion above.
fct /= self._transform.codomain.weight()
fct /= self.codomain.weight()
Tval *= fct
return Tval
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment