Skip to content
Snippets Groups Projects
Forked from ift / NIFTy
2817 commits behind, 32 commits ahead of the upstream repository.
tensor_dot_operator.py 11.77 KiB
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

import numpy as np

from .. import utilities
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from .linear_operator import LinearOperator

class TensorDotOperator(LinearOperator):
    """Contraction of the last few dimensions
    of a tensor with selected dimensions of input field.

    This operator supports scipy.sparse matrices and numpy arrays
    as the tensor to be contracted with. Its output coincides with the
    endomorphic MatrixProductOperator with the same input arguments
    whenever the latter is applicable.

    For numpy arrays, it can contract the tensor with any subspaces
    of the input.

    If the input arrays have more than one dimension, for
    scipy.sparse matrices the `flatten` keyword argument must be
    set to true. This means that the input field will be flattened
    before contracting with the tensor and reshaped to its original shape
    afterwards. Flattening is only supported when the domain and target
    are the same, and a target can't be specified if flatten=True.

    The tensor is checked for shape compatibility with the summed-over
    part of the domain when the operator is constructed.

    Flattening and subspace application are mutually exclusive.

    The target space type and distances can be specified. If
    unspecified, it defaults to an RGSpace with default distance
    convention. An exception is if the target space shape is the
    same as the domain's shape, in which case the default target
    space is the domain itself. Either a single domain or a
    DomainTuple of the valid shape can be set as the target.

    When applied to specific subspaces of the domain, the
    non-participating subspaces of the domain retain their positions
    in the target space. The order of other axes in the target space is
    the tensor's axes in their original order. This convention is for
    preserving the space's shape when the application is endomorphic.

    A technicality related to this point: The absolute positions of the
    non-participating subspaces of the domain cannot be different in the
    target space, thus the tensor must have enough unsummed axes to stand
    in the places of summed-over axes of the domain, if those summed-over
    axes are followed by any unsummed (inactive) axes. Example to make
    this clear: If the first 2 spaces of the domain are summed over in
    the contraction and the 3rd space doesn't participate, the
    tensor must have (at least) 2 axes that don't participate in the
    multiplication, so that these 2 axes can come before the 3rd subspace
    of the domain takes its position in the target space. Otherwise an
    error will be raised when the operator is applied informing that
    the tensor has too few extra dimensions.

    Parameters
    ----------
    domain: :class:`Domain` or :class:`DomainTuple`
        Domain of the operator.
    tensor: scipy.sparse matrix or numpy array
        Tensor of a shape whose last few axes should match the shape
        of the axes of the field to be summed over (if not 'flatten').
        If it is not a numpy array, needs to be applicable to the val
        array of input fields by `tensor.dot()`.
    spaces: int or tuple of int, optional
        The subdomain(s) of "domain" which the operator acts on.
        If None, it acts on all elements.
        Only possible for numpy array matrices.
        If `len(domain) > 1` and `flatten=False`, this parameter is
        mandatory.
    flatten: boolean, optional
        Whether the input value array should be flattened before
        contracting with the field and reshaped to its original shape
        afterwards.
        Needed for scipy.sparse matrices if `len(domain) > 1`.
    target: :class:`Domain` or :class:`DomainTuple`, optional
        Target of the operator. It must be of the valid shape, other
        parameters like domain type and distances are flexible. The
        default is an RGSpace with default distances convention.

    Raises
    ------
    ValueError
        If the tensor has fewer dimensions than the summed-over part of
        the domain, if its trailing axes do not match that part of the
        domain, if `flatten` is combined with `spaces` or `target`, if
        `spaces` is given for a tensor that is not a numpy array, or if
        `target` does not have the expected shape.
    """
    def __init__(self, domain, tensor, spaces=None, flatten=False, target=None):
        self._capability = self.TIMES | self.ADJOINT_TIMES
        self._domain = DomainTuple.make(domain)

        tensor_dim = len(tensor.shape)
        # Read the shape from the normalised DomainTuple, not from the raw
        # `domain` argument, so every input DomainTuple.make accepts works.
        domain_shape = self._domain.shape

        # take shortcut for trivial case
        if spaces is not None and len(self._domain.shape) == 1 and spaces == (0, ):
            spaces = None

        if spaces is None:
            self._spaces = None
            self._active_axes = utilities.my_sum(self._domain.axes)
            self._inactive_axes = ()
            if flatten:
                # The field is raveled before contraction, so the tensor
                # only has to match the total number of pixels.
                domain_shape = (utilities.my_product(domain_shape), )
            tensor_inactive_dim = tensor_dim - len(domain_shape)
            if tensor_inactive_dim < 0:
                raise ValueError("Domain has more dimensions than tensor.")
            target_space_shape = tensor.shape[:tensor_inactive_dim]
            target_dim = tensor_inactive_dim

        else:
            if flatten:
                raise ValueError(
                    "Cannot flatten input AND apply to a subspace")
            if not isinstance(tensor, np.ndarray):
                raise ValueError(
                    "Application to subspaces only supported for numpy array matrices.")

            self._spaces = utilities.parse_spaces(spaces, len(self._domain))
            active_axes = []
            # NOTE(review): the entries collected here are *space* indices,
            # yet `apply` passes them to np.moveaxis as array axes. The two
            # presumably coincide only when every subspace is
            # one-dimensional -- confirm before relying on
            # multi-dimensional subspaces.
            inactive = list(range(len(self._domain)))
            domain_shape = []
            for space_idx in self._spaces:
                domain_shape += self._domain[space_idx].shape
                active_axes += self._domain.axes[space_idx]
                if space_idx in inactive:
                    inactive.remove(space_idx)
            domain_shape = tuple(domain_shape)
            self._active_axes = tuple(active_axes)
            self._inactive_axes = tuple(inactive)
            tensor_inactive_dim = tensor_dim - len(domain_shape)
            if tensor_inactive_dim < 0:
                raise ValueError("Domain has more dimensions than tensor.")

            n_inactive = len(self._inactive_axes)
            target_dim = tensor_inactive_dim + n_inactive
            domain_dim = len(domain_shape)

            # Inactive subspaces keep their position in the target; every
            # other slot is filled with the tensor's leading axes in order.
            target_space_shape = []
            tensor_shape_idx = 0
            for i in range(target_dim):
                if i in self._inactive_axes:
                    target_space_shape.extend(self._domain[i].shape)
                else:
                    target_space_shape.append(tensor.shape[tensor_shape_idx])
                    tensor_shape_idx += 1
            target_space_shape = tuple(target_space_shape)

            # Negative axis indices (-k, ..., -1) used for the contractions.
            self._tensor_last_n = tuple(range(-domain_dim, 0))
            self._tensor_first_n = np.arange(domain_dim)
            self._target_last_n = tuple(range(-n_inactive, 0))

        # tensor_last_m is needed for adjoint application even if spaces = None
        self._tensor_last_m = tuple(range(-tensor_inactive_dim, 0))
        self._target_axes = tuple(range(len(target_space_shape)))
        if spaces is not None:
            self._field_axes = tuple(i for i in self._target_axes
                                     if i not in self._inactive_axes)

        if target is None:
            if flatten or target_space_shape == self._domain.shape:
                # Endomorphic application: reuse the domain as the target.
                self._target = self._domain
            elif target_dim != 0:
                self._target = DomainTuple.make(RGSpace(shape=target_space_shape))
            else:
                # Full contraction: the result has no axes left.
                self._target = DomainTuple.make(None)
        elif flatten:
            raise ValueError("Flattening is supported only for endomorphic application,"
                             + " and you can't specify a target.")
        else:
            # Normalise first so that the shape is read from a DomainTuple,
            # whatever form the caller used for `target`.
            target = DomainTuple.make(target)
            if target.shape != target_space_shape:
                raise ValueError("Target space has invalid shape.\n"
                                 +f"Its shape should be {target_space_shape}.")
            self._target = target

        tensor_appl_shape = tensor.shape[tensor_inactive_dim:]
        if tensor_appl_shape != domain_shape:
            raise ValueError("Tensor doesn't fit with the domain.\n" +
                f"Shape of tensor axes used in summation: {tensor_appl_shape},\n " +
                f"Shape of domain axes used in summation: {domain_shape}.")

        self._tensor = tensor
        # Conjugate transpose (all axes reversed), used for adjoint application.
        self._tensor_tr = tensor.transpose().conjugate()
        self._flatten = flatten

    def apply(self, x, mode):
        """Apply the operator (or its adjoint) to the field `x`.

        Parameters
        ----------
        x : Field
            Input field; it is validated by `_check_input`.
        mode : int
            `self.TIMES` applies the tensor contraction; any other mode
            that passes `_check_input` is handled as the adjoint, using
            the conjugate-transposed tensor.

        Returns
        -------
        Field
            Field on the target (for `TIMES`) or on the domain (otherwise).

        Raises
        ------
        ValueError
            If the tensor has too few unsummed axes to keep the inactive
            subspaces at their positions.
        """
        self._check_input(x, mode)
        times = (mode == self.TIMES)
        t = self._tensor if times else self._tensor_tr
        target = self._target if times else self._domain

        if self._spaces is None:
            if self._flatten:
                res = t.dot(x.val.ravel()).reshape(self._domain.shape)
            elif type(t) is np.ndarray:
                # Exact-type test kept on purpose: ndarray subclasses and
                # sparse matrices go through their own `dot` below.
                if times:
                    res = np.tensordot(t, x.val, axes=len(self._domain.shape))
                else:
                    tensor_axes = np.flip(self._tensor_last_m)
                    field_axes = self._target_axes
                    res = np.tensordot(t, x.val, axes=(tensor_axes, field_axes))
                    # tensordot leaves the remaining (reversed) tensor axes
                    # in front; transpose restores the domain's axis order.
                    res = res.transpose()
            else:
                res = t.dot(x.val)
            return Field(target, res)

        if times:
            tensor_axes = self._tensor_last_n
            move_axes = self._target_last_n
            res = np.tensordot(t, x.val, axes=(tensor_axes, self._active_axes))
            # AxisError left the top-level numpy namespace in NumPy 2.0;
            # fall back to `np` itself on versions without `np.exceptions`.
            axis_error = getattr(np, "exceptions", np).AxisError
            try:
                res = np.moveaxis(res, move_axes, self._inactive_axes)
            except axis_error as err:
                raise ValueError("The tensor has too few extra dimensions.\n" +
                                 "Number of dimensions in tensor:" +
                                 f"{len(t.shape)}\n") from err
        else:
            tensor_axes = np.flip(self._tensor_last_m)
            move_axes = np.flip(self._tensor_first_n)
            field_axes = self._field_axes
            try:
                res = np.tensordot(t, x.val, axes=(tensor_axes, field_axes))
            except ValueError as err:
                # Only numpy's shape-mismatch error is translated; anything
                # else is propagated untouched.
                if err.args[0] == "shape-mismatch for sum":
                    raise ValueError("The tensor has too few extra dimensions.\n" +
                                      "Number of dimensions in tensor: " +
                                      f"{len(self._tensor.shape)}\n") from err
                raise
            res = np.moveaxis(res, move_axes, self._active_axes)
        return Field(target, res)