Commit 3b36c8a3 authored by Lukas Platz's avatar Lukas Platz

unconditionally convolve without signal mean

parent 1439a2b8
...@@ -30,7 +30,7 @@ from ..field import Field ...@@ -30,7 +30,7 @@ from ..field import Field
from .. import utilities from .. import utilities
def FuncConvolutionOperator(domain, func, space=None, without_mean=None): def FuncConvolutionOperator(domain, func, space=None):
"""Convolves input with a radially symmetric kernel defined by `func` """Convolves input with a radially symmetric kernel defined by `func`
Parameters Parameters
...@@ -65,15 +65,10 @@ def FuncConvolutionOperator(domain, func, space=None, without_mean=None): ...@@ -65,15 +65,10 @@ def FuncConvolutionOperator(domain, func, space=None, without_mean=None):
raise TypeError("unsupported domain") raise TypeError("unsupported domain")
codomain = domain[space].get_default_codomain() codomain = domain[space].get_default_codomain()
kernel = codomain.get_conv_kernel_from_func(func) kernel = codomain.get_conv_kernel_from_func(func)
if without_mean == None: return _ConvolutionOperator(domain, kernel, space)
if isinstance(domain[space], (HPSpace, GLSpace)):
without_mean = True
else:
without_mean = False
return _ConvolutionOperator(domain, kernel, space, without_mean)
def _ConvolutionOperator(domain, kernel, space=None, without_mean=False): def _ConvolutionOperator(domain, kernel, space=None):
domain = DomainTuple.make(domain) domain = DomainTuple.make(domain)
space = utilities.infer_space(domain, space) space = utilities.infer_space(domain, space)
if len(kernel.domain) != 1: if len(kernel.domain) != 1:
...@@ -89,10 +84,7 @@ def _ConvolutionOperator(domain, kernel, space=None, without_mean=False): ...@@ -89,10 +84,7 @@ def _ConvolutionOperator(domain, kernel, space=None, without_mean=False):
diag = DiagonalOperator(kernel*domain[space].total_volume, lm, (space,)) diag = DiagonalOperator(kernel*domain[space].total_volume, lm, (space,))
wgt = WeightApplier(domain, space, 1) wgt = WeightApplier(domain, space, 1)
op = HT(diag(HT.adjoint(wgt))) op = HT(diag(HT.adjoint(wgt)))
if without_mean: return _ApplicationWithoutMeanOperator(op)
return _ApplicationWithoutMeanOperator(op)
else:
return op
class _ApplicationWithoutMeanOperator(EndomorphicOperator): class _ApplicationWithoutMeanOperator(EndomorphicOperator):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment