Commits (2)
......@@ -30,7 +30,7 @@ from ..field import Field
from .. import utilities
def FuncConvolutionOperator(domain, func, space=None):
def FuncConvolutionOperator(domain, func, space=None, without_mean=None):
"""Convolves input with a radially symmetric kernel defined by `func`
......@@ -45,6 +45,12 @@ def FuncConvolutionOperator(domain, func, space=None):
The index of the subdomain on which the operator should act
If None, it is set to 0 if `domain` contains exactly one space.
`domain[space]` must be of type `RGSpace`, `HPSpace`, or `GLSpace`.
without_mean: bool, optional
If `None`, chooses domain-dependant default value:
- `True` for spherical domains (`HPSpace`, `GLSpace`)
- `False` for RGSpaces.
If `True`, subtracts the input mean before applying the convolution
and adds it back afterwards.
......@@ -59,10 +65,15 @@ def FuncConvolutionOperator(domain, func, space=None):
raise TypeError("unsupported domain")
codomain = domain[space].get_default_codomain()
kernel = codomain.get_conv_kernel_from_func(func)
return _ConvolutionOperator(domain, kernel, space)
if without_mean == None:
if isinstance(domain[space], (HPSpace, GLSpace)):
without_mean = True
without_mean = False
return _ConvolutionOperator(domain, kernel, space, without_mean)
def _ConvolutionOperator(domain, kernel, space=None):
def _ConvolutionOperator(domain, kernel, space=None, without_mean=False):
domain = DomainTuple.make(domain)
space = utilities.infer_space(domain, space)
if len(kernel.domain) != 1:
......@@ -77,4 +88,28 @@ def _ConvolutionOperator(domain, kernel, space=None):
HT = HarmonicTransformOperator(lm, domain[space], space)
diag = DiagonalOperator(kernel*domain[space].total_volume, lm, (space,))
wgt = WeightApplier(domain, space, 1)
return HT(diag(HT.adjoint(wgt)))
op = HT(diag(HT.adjoint(wgt)))
if without_mean:
return _ApplicationWithoutMeanOperator(domain, op)
return op
class _ApplicationWithoutMeanOperator(EndomorphicOperator):
def __init__(self, domain, op):
self._capability = self.TIMES | self.ADJOINT_TIMES
self._domain = domain
self._op = op
if (op.domain != domain) or (op.domain != op.target):
raise TypeError("domains incompatible")
def apply(self, x, mode):
self._check_input(x, mode)
mean = x.mean()
return mean + self._op.apply(x - mean, mode)
def __repr__(self):
from ..utilities import indent
return "\n".join((