Commit 93b50d6e authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'fix_fft_normalization' into 'feature/field_multiple_space'

Fix normalization in FFTOperator



See merge request !24
parents 93d877c3 d131e95b
...@@ -277,6 +277,7 @@ class Field(object): ...@@ -277,6 +277,7 @@ class Field(object):
def power_synthesize(self): def power_synthesize(self):
# check that all spaces in self.domain are real or instances of power_space # check that all spaces in self.domain are real or instances of power_space
# check if field is real- or complex-valued # check if field is real- or complex-valued
pass
# ---Properties--- # ---Properties---
......
...@@ -26,8 +26,7 @@ class FFTOperator(LinearOperator): ...@@ -26,8 +26,7 @@ class FFTOperator(LinearOperator):
if target is None: if target is None:
target = utilities.get_default_codomain(self.domain[0]) target = utilities.get_default_codomain(self.domain[0])
self._target = self._parse_domain( self._target = self._parse_domain(target)
utilities.get_default_codomain(self.domain[0]))
self._forward_transformation = TransformationFactory.create( self._forward_transformation = TransformationFactory.create(
self.domain[0], self.target[0] self.domain[0], self.target[0]
...@@ -39,8 +38,11 @@ class FFTOperator(LinearOperator): ...@@ -39,8 +38,11 @@ class FFTOperator(LinearOperator):
def _times(self, x, spaces, types): def _times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None:
axes = None
else:
axes = x.domain_axes[spaces[0]]
axes = x.domain_axes[spaces[0]]
new_val = self._forward_transformation.transform(x.val, axes=axes) new_val = self._forward_transformation.transform(x.val, axes=axes)
if spaces is None: if spaces is None:
...@@ -56,6 +58,10 @@ class FFTOperator(LinearOperator): ...@@ -56,6 +58,10 @@ class FFTOperator(LinearOperator):
def _inverse_times(self, x, spaces, types): def _inverse_times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
if spaces is None:
axes = None
else:
axes = x.domain_axes[spaces[0]]
axes = x.domain_axes[spaces[0]] axes = x.domain_axes[spaces[0]]
new_val = self._inverse_transformation.transform(x.val, axes=axes) new_val = self._inverse_transformation.transform(x.val, axes=axes)
......
...@@ -549,10 +549,9 @@ class GFFT(Transform): ...@@ -549,10 +549,9 @@ class GFFT(Transform):
""" """
def __init__(self, domain, codomain, fft_module): def __init__(self, domain, codomain, fft_module=None):
if fft_module is None: if fft_module is None:
# gdi cannot find the required module fft_module = gdi['gfft_dummy']
raise ImportError("ERROR: GFFT module is not available.")
self.domain = domain self.domain = domain
self.codomain = codomain self.codomain = codomain
...@@ -610,15 +609,10 @@ class GFFT(Transform): ...@@ -610,15 +609,10 @@ class GFFT(Transform):
in_ax=[], in_ax=[],
out_ax=[], out_ax=[],
ftmachine='fft' if self.codomain.harmonic else 'ifft', ftmachine='fft' if self.codomain.harmonic else 'ifft',
in_zero_center=map( in_zero_center=map(bool, self.domain.zerocenter),
bool, self.domain.zerocenter out_zero_center=map(bool, self.codomain.zerocenter),
), # enforce_hermitian_symmetry=bool(self.codomain.complexity),
out_zero_center=map( enforce_hermitian_symmetry=False,
bool, self.codomain.zerocenter
),
enforce_hermitian_symmetry=bool(
self.codomain.complexity
),
W=-1, W=-1,
alpha=-1, alpha=-1,
verbose=False verbose=False
......
...@@ -125,14 +125,19 @@ class RGRGTransformation(Transformation): ...@@ -125,14 +125,19 @@ class RGRGTransformation(Transformation):
""" """
if self._transform.codomain.harmonic: if self._transform.codomain.harmonic:
# correct for forward fft # correct for forward fft.
# naively one would set power to 0.5 here in order to
# apply effectively a factor of 1/sqrt(N) to the field.
# BUT: the pixel volumes of the domain and codomain are different.
# Hence, in order to produce the same scalar product, power===1.
val = self._transform.domain.weight(val, power=1, axes=axes) val = self._transform.domain.weight(val, power=1, axes=axes)
# Perform the transformation # Perform the transformation
Tval = self._transform.transform(val, axes, **kwargs) Tval = self._transform.transform(val, axes, **kwargs)
if not self._transform.codomain.harmonic: if not self._transform.codomain.harmonic:
# correct for inverse fft # correct for inverse fft.
# See discussion above.
Tval = self._transform.codomain.weight(Tval, power=-1, axes=axes) Tval = self._transform.codomain.weight(Tval, power=-1, axes=axes)
return Tval return Tval
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment