Commit fa9303d5 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

FFTSmoothingOperator does not accept logarithmic distances any more

parent f64c8f54
Pipeline #16821 passed with stage
in 14 minutes and 17 seconds
...@@ -8,12 +8,12 @@ from .smoothing_operator import SmoothingOperator ...@@ -8,12 +8,12 @@ from .smoothing_operator import SmoothingOperator
class FFTSmoothingOperator(SmoothingOperator): class FFTSmoothingOperator(SmoothingOperator):
def __init__(self, domain, sigma, log_distances=False, def __init__(self, domain, sigma,
default_spaces=None): default_spaces=None):
super(FFTSmoothingOperator, self).__init__( super(FFTSmoothingOperator, self).__init__(
domain=domain, domain=domain,
sigma=sigma, sigma=sigma,
log_distances=log_distances, log_distances=False,
default_spaces=default_spaces) default_spaces=default_spaces)
self._transformator_cache = {} self._transformator_cache = {}
...@@ -32,10 +32,6 @@ class FFTSmoothingOperator(SmoothingOperator): ...@@ -32,10 +32,6 @@ class FFTSmoothingOperator(SmoothingOperator):
kernel = codomain.get_distance_array( kernel = codomain.get_distance_array(
distribution_strategy=axes_local_distribution_strategy) distribution_strategy=axes_local_distribution_strategy)
#MR FIXME: this causes calls of log(0.) which should probably be avoided
if self.log_distances:
kernel.apply_scalar_function(np.log, inplace=True)
kernel.apply_scalar_function( kernel.apply_scalar_function(
codomain.get_fft_smoothing_kernel_function(self.sigma), codomain.get_fft_smoothing_kernel_function(self.sigma),
inplace=True) inplace=True)
...@@ -52,7 +48,8 @@ class FFTSmoothingOperator(SmoothingOperator): ...@@ -52,7 +48,8 @@ class FFTSmoothingOperator(SmoothingOperator):
# apply the kernel # apply the kernel
if inverse: if inverse:
#MR FIXME: danger of having division by zero or overflows # avoid zeroes in the kernel to work around divisions by zero
local_kernel = np.maximum(1e-12,local_kernel)
local_transformed_x /= local_kernel local_transformed_x /= local_kernel
else: else:
local_transformed_x *= local_kernel local_transformed_x *= local_kernel
......
...@@ -18,8 +18,7 @@ ...@@ -18,8 +18,7 @@
import unittest import unittest
import numpy as np import numpy as np
from numpy.testing import assert_equal, assert_approx_equal,\ from numpy.testing import assert_equal, assert_allclose
assert_allclose
from nifty import Field,\ from nifty import Field,\
RGSpace,\ RGSpace,\
...@@ -39,10 +38,9 @@ def _get_rtol(tp): ...@@ -39,10 +38,9 @@ def _get_rtol(tp):
class SmoothingOperator_Tests(unittest.TestCase): class SmoothingOperator_Tests(unittest.TestCase):
spaces = [RGSpace(128)] spaces = [RGSpace(128)]
@expand(product(spaces, [0., .5, 5.], [True, False])) @expand(product(spaces, [0., .5, 5.]))
def test_property(self, space, sigma, log_distances): def test_property(self, space, sigma):
op = SmoothingOperator(space, sigma=sigma, op = SmoothingOperator(space, sigma=sigma)
log_distances=log_distances)
if op.domain[0] != space: if op.domain[0] != space:
raise TypeError raise TypeError
if op.unitary != False: if op.unitary != False:
...@@ -51,37 +49,34 @@ class SmoothingOperator_Tests(unittest.TestCase): ...@@ -51,37 +49,34 @@ class SmoothingOperator_Tests(unittest.TestCase):
raise ValueError raise ValueError
if op.sigma != sigma: if op.sigma != sigma:
raise ValueError raise ValueError
if op.log_distances != log_distances: if op.log_distances != False:
raise ValueError raise ValueError
@expand(product(spaces, [0., .5, 5.], [True, False])) @expand(product(spaces, [0., .5, 5.]))
def test_adjoint_times(self, space, sigma, log_distances): def test_adjoint_times(self, space, sigma):
op = SmoothingOperator(space, sigma=sigma, op = SmoothingOperator(space, sigma=sigma)
log_distances=log_distances)
rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space)
rand2 = Field.from_random('normal', domain=space) rand2 = Field.from_random('normal', domain=space)
tt1 = rand1.vdot(op.times(rand2)) tt1 = rand1.vdot(op.times(rand2))
tt2 = rand2.vdot(op.adjoint_times(rand1)) tt2 = rand2.vdot(op.adjoint_times(rand1))
assert_approx_equal(tt1, tt2) assert_allclose(tt1, tt2)
@expand(product(spaces, [0., .5, 5.], [False])) @expand(product(spaces, [0., .5, 5.]))
def test_times(self, space, sigma, log_distances): def test_times(self, space, sigma):
op = SmoothingOperator(space, sigma=sigma, op = SmoothingOperator(space, sigma=sigma)
log_distances=log_distances)
rand1 = Field(space, val=0.) rand1 = Field(space, val=0.)
rand1.val[0] = 1. rand1.val[0] = 1.
tt1 = op.times(rand1) tt1 = op.times(rand1)
assert_approx_equal(1, tt1.sum()) assert_allclose(1, tt1.sum())
@expand(product(spaces, [0., .5, 5.], [True, False])) @expand(product(spaces, [0., .5, 5.]))
def test_inverse_adjoint_times(self, space, sigma, log_distances): def test_inverse_adjoint_times(self, space, sigma):
op = SmoothingOperator(space, sigma=sigma, op = SmoothingOperator(space, sigma=sigma)
log_distances=log_distances)
rand1 = Field.from_random('normal', domain=space) rand1 = Field.from_random('normal', domain=space)
rand2 = Field.from_random('normal', domain=space) rand2 = Field.from_random('normal', domain=space)
tt1 = rand1.vdot(op.inverse_times(rand2)) tt1 = rand1.vdot(op.inverse_times(rand2))
tt2 = rand2.vdot(op.inverse_adjoint_times(rand1)) tt2 = rand2.vdot(op.inverse_adjoint_times(rand1))
assert_approx_equal(tt1, tt2) assert_allclose(tt1, tt2)
@expand(product([128, 256], [1, 0.4], [0., 1., 3.7], @expand(product([128, 256], [1, 0.4], [0., 1., 3.7],
[np.float64, np.complex128])) [np.float64, np.complex128]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment