# convolution_operators.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from ..domains.rg_space import RGSpace
from ..domains.lm_space import LMSpace
from ..domains.hp_space import HPSpace
from ..domains.gl_space import GLSpace
from .endomorphic_operator import EndomorphicOperator
from .harmonic_operators import HarmonicTransformOperator
from .diagonal_operator import DiagonalOperator
from .simple_linear_operators import WeightApplier
from ..domain_tuple import DomainTuple
from ..field import Field
from .. import utilities


def FuncConvolutionOperator(domain, func, space=None):
    """Convolves input with a radially symmetric kernel defined by `func`

    Parameters
    ----------
    domain: DomainTuple
        Domain of the operator.
    func: function
        This function needs to take exactly one argument, which is
        colatitude in radians, and return the kernel amplitude at that
        colatitude.
    space: int, optional
        The index of the subdomain on which the operator should act.
        If None, it is set to 0 if `domain` contains exactly one space.
        `domain[space]` must be of type `RGSpace`, `HPSpace`, or `GLSpace`.

    Returns
    -------
    EndomorphicOperator
        The convolution operator acting on `domain`.

    Raises
    ------
    TypeError
        If `domain[space]` is not an `RGSpace`, `HPSpace`, or `GLSpace`.

    Notes
    -----
    The mean of the input is always subtracted before the convolution is
    applied and added back afterwards; this function offers no switch to
    disable that behaviour.

    The operator assumes periodic boundaries in the input domain. This means
    for a sufficiently broad function a point source close to the boundary will
    blur into the opposite side of the image. Zero padding can be applied to
    avoid this behaviour.
    """
    domain = DomainTuple.make(domain)
    space = utilities.infer_space(domain, space)
    if not isinstance(domain[space], (RGSpace, HPSpace, GLSpace)):
        raise TypeError("unsupported domain")
    # The kernel is tabulated on the default codomain of the target space,
    # which is the space the harmonic transform in _ConvolutionOperator
    # maps from.
    codomain = domain[space].get_default_codomain()
    kernel = codomain.get_conv_kernel_from_func(func)
    return _ConvolutionOperator(domain, kernel, space)


def _ConvolutionOperator(domain, kernel, space=None):
    """Build the operator convolving `domain[space]` with a harmonic kernel.

    Parameters
    ----------
    domain: DomainTuple
        Domain of the operator.
    kernel: Field
        Kernel values (presumably a `Field`; it must expose `.domain` and be
        accepted by `DiagonalOperator`). Its domain must consist of exactly
        one space, equal to the default codomain of `domain[space]`.
    space: int, optional
        Index of the subdomain to act on. If None, it is inferred by
        `utilities.infer_space`.

    Returns
    -------
    _ApplicationWithoutMeanOperator
        Wraps ``HT(diag(HT.adjoint(wgt)))``. The input mean is therefore
        removed before, and restored after, the convolution.

    Raises
    ------
    ValueError
        If `kernel` is not defined on exactly one space, or that space is
        not the default codomain of `domain[space]`.
    TypeError
        If `domain[space]` is not an `RGSpace`, `HPSpace`, or `GLSpace`.
    """
    domain = DomainTuple.make(domain)
    space = utilities.infer_space(domain, space)
    if len(kernel.domain) != 1:
        raise ValueError("kernel needs exactly one domain")
    if not isinstance(domain[space], (HPSpace, GLSpace, RGSpace)):
        raise TypeError("need RGSpace, HPSpace, or GLSpace")
    # Same domain tuple, with only `space` swapped for its default codomain.
    lm = list(domain)
    lm[space] = lm[space].get_default_codomain()
    lm = DomainTuple.make(lm)
    if lm[space] != kernel.domain[0]:
        raise ValueError("Input domain and kernel are incompatible")
    HT = HarmonicTransformOperator(lm, domain[space], space)
    # Kernel values are scaled by the total volume of the position space.
    diag = DiagonalOperator(kernel*domain[space].total_volume, lm, (space,))
    # Weight the input (power 1) on `space` before the adjoint transform.
    wgt = WeightApplier(domain, space, 1)
    op = HT(diag(HT.adjoint(wgt)))
    return _ApplicationWithoutMeanOperator(op)


class _ApplicationWithoutMeanOperator(EndomorphicOperator):
    """Apply an endomorphic operator to the mean-free part of the input.

    For an input `x` with mean ``m = x.mean()`` this returns
    ``m + op(x - m)``, so the mean bypasses `op` unchanged. The same
    wrapping is used for both `TIMES` and `ADJOINT_TIMES`.

    Parameters
    ----------
    op: operator
        Operator exposing `.domain`, `.target` and `.apply(x, mode)`;
        its domain must equal its target.

    Raises
    ------
    TypeError
        If `op.domain` differs from `op.target`.
    """

    def __init__(self, op):
        self._capability = self.TIMES | self.ADJOINT_TIMES
        if op.domain != op.target:
            raise TypeError("Operator needs to be endomorphic")
        self._domain = op.domain
        self._op = op

    def apply(self, x, mode):
        """Return ``mean + op.apply(x - mean, mode)`` with ``mean = x.mean()``."""
        self._check_input(x, mode)
        mean = x.mean()
        return mean + self._op.apply(x - mean, mode)

    def __repr__(self):
        return "\n".join((
            "_ApplicationWithoutMeanOperator:",
            utilities.indent(repr(self._op))))