Commit f37c42e8 authored by Lukas Platz's avatar Lukas Platz
Browse files

FuncConvolutionOperator: apply convolution with mean subtracted

parent a6d80b97
Pipeline #48043 passed with stages
in 8 minutes and 18 seconds
......@@ -30,7 +30,7 @@ from ..field import Field
from .. import utilities
def FuncConvolutionOperator(domain, func, space=None):
def FuncConvolutionOperator(domain, func, space=None, without_mean=False):
"""Convolves input with a radially symmetric kernel defined by `func`
......@@ -45,6 +45,9 @@ def FuncConvolutionOperator(domain, func, space=None):
The index of the subdomain on which the operator should act
If None, it is set to 0 if `domain` contains exactly one space.
`domain[space]` must be of type `RGSpace`, `HPSpace`, or `GLSpace`.
without_mean: bool, optional
If `True`, subtracts the input mean before applying the convolution
and adds it back afterwards.
......@@ -59,10 +62,10 @@ def FuncConvolutionOperator(domain, func, space=None):
raise TypeError("unsupported domain")
codomain = domain[space].get_default_codomain()
kernel = codomain.get_conv_kernel_from_func(func)
return _ConvolutionOperator(domain, kernel, space)
return _ConvolutionOperator(domain, kernel, space, without_mean)
def _ConvolutionOperator(domain, kernel, space=None):
def _ConvolutionOperator(domain, kernel, space=None, without_mean=False):
domain = DomainTuple.make(domain)
space = utilities.infer_space(domain, space)
if len(kernel.domain) != 1:
......@@ -77,4 +80,28 @@ def _ConvolutionOperator(domain, kernel, space=None):
HT = HarmonicTransformOperator(lm, domain[space], space)
diag = DiagonalOperator(kernel*domain[space].total_volume, lm, (space,))
wgt = WeightApplier(domain, space, 1)
return HT(diag(HT.adjoint(wgt)))
op = HT(diag(HT.adjoint(wgt)))
if without_mean:
return _ApplicationWithoutMeanOperator(domain, op)
return op
class _ApplicationWithoutMeanOperator(EndomorphicOperator):
def __init__(self, domain, op):
self._capability = self.TIMES | self.ADJOINT_TIMES
self._domain = domain
self._op = op
if (op.domain != domain) or (op.domain !=
raise TypeError("domains incompatible")
def apply(self, x, mode):
self._check_input(x, mode)
mean = x.mean()
return mean + self._op.apply(x - mean, mode)
def __repr__(self):
from ..utilities import indent
return "\n".join((
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment