diff --git a/demos/misc/convolution_on_sphere.py b/demos/misc/convolution_on_sphere.py index 24e1a58338be692e53e773f472aef3308395761c..79e9a59151a433f73c020f8c1add05a9c2eed3a8 100644 --- a/demos/misc/convolution_on_sphere.py +++ b/demos/misc/convolution_on_sphere.py @@ -15,11 +15,13 @@ for i in range(0, npix, npix//12 + 27): signal_vals[i] = 1. signal = ift.from_global_data(dom_tuple, signal_vals) + # Define kernel function def func(theta): ct = np.cos(theta) return 1. * np.logical_and(ct > 0.7, ct <= 0.8) + # Create Convolution Operator conv_op = ift.SphericalFuncConvolutionOperator(dom_tuple, func) diff --git a/nifty5/operators/convolution_operators.py b/nifty5/operators/convolution_operators.py index 26144bceb7dc6a104257e6ec0f68888258673d80..6648d433ce5ca277480de01d2c662f994da3da56 100644 --- a/nifty5/operators/convolution_operators.py +++ b/nifty5/operators/convolution_operators.py @@ -26,32 +26,25 @@ from ..domain_tuple import DomainTuple from ..field import Field -class SphericalFuncConvolutionOperator(EndomorphicOperator): +def SphericalFuncConvolutionOperator(domain, func): """Convolves input with a radially symmetric kernel defined by `func` Parameters ---------- - domain: domain of the operator - func: function defining the sperical convolution kernel - dependant only on theta in radians + domain: DomainTuple + Domain of the operator. Must have exactly one entry, which is + of type `HPSpace` or `GLSpace`. + func: function + This function needs to take exactly one argument, which is + colatitude in radians, and return the kernel amplitude at that + colatitude. """ - - def __init__(self, domain, func): - if len(domain) != 1: - raise ValueError("need exactly one domain") - if not isinstance(domain[0], (HPSpace, GLSpace)): - raise TypeError("need a spherical domain") - self._domain = domain - self.lm = domain[0].get_default_codomain() - self.kernel = self.lm.get_conv_kernel_from_func(func) - self.HT = HarmonicTransformOperator(self.lm, domain[0]) - self._capability = self.TIMES | self.ADJOINT_TIMES - - def apply(self, x, mode): - self._check_input(x, mode) - x_lm = self.HT.adjoint_times(x.weight(1)) - x_lm = x_lm * self.kernel * (4. * np.pi) - return self.HT(x_lm) + if len(domain) != 1: + raise ValueError("need exactly one domain") + if not isinstance(domain[0], (HPSpace, GLSpace)): + raise TypeError("need a spherical domain") + kernel = domain[0].get_default_codomain().get_conv_kernel_from_func(func) + return SphericalConvolutionOperator(domain, kernel) class SphericalConvolutionOperator(EndomorphicOperator):