Skip to content
Snippets Groups Projects
field.py 20.14 KiB
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

from functools import reduce
import numpy as np

from . import utilities
from .domain_tuple import DomainTuple


class Field(object):
    """The discrete representation of a continuous field over multiple spaces.

    Stores data arrays and carries all the needed meta-information (i.e. the
    domain) for operators to be able to operate on them.

    Parameters
    ----------
    domain : DomainTuple
        The domain of the new Field.
    val : numpy.ndarray or scalar
        If an array, its shape must match the domain shape.
        After construction, the array will no longer be writeable!
        If a scalar, a new array of the domain's shape filled with this
        value is created.

    Raises
    ------
    TypeError
        If `domain` is not a DomainTuple, or `val` is neither a
        numpy.ndarray nor a scalar.
    ValueError
        If the shape of `val` does not match the shape of `domain`.

    Notes
    -----
    If possible, do not invoke the constructor directly, but use one of the
    many convenience functions for instantiation!
    """

    # Domain used by `Field.scalar`; created once and shared by all calls.
    _scalar_dom = DomainTuple.scalar_domain()

    def __init__(self, domain, val):
        if not isinstance(domain, DomainTuple):
            raise TypeError("domain must be of type DomainTuple")
        if not isinstance(val, np.ndarray):
            if np.isscalar(val):
                # Expand a scalar into a constant array of the domain's shape.
                val = np.full(domain.shape, val)
            else:
                raise TypeError("val must be of type numpy.ndarray")
        if domain.shape != val.shape:
            raise ValueError("shape mismatch between val and domain")
        self._domain = domain
        self._val = val
        # Fields are immutable. Note that this freezes the very array object
        # passed in by the caller (no copy is made).
        self._val.flags.writeable = False

    @staticmethod
    def scalar(val):
        """Creates a Field on the scalar domain holding the value `val`."""
        return Field(Field._scalar_dom, val)

    # prevent implicit conversion to bool
    # (`__nonzero__` is the Python 2 spelling of `__bool__`)
    def __nonzero__(self):
        raise TypeError("Field does not support implicit conversion to bool")

    def __bool__(self):
        raise TypeError("Field does not support implicit conversion to bool")

    @staticmethod
    def full(domain, val):
        """Creates a Field with a given domain, filled with a constant value.

        Parameters
        ----------
        domain : Domain, tuple of Domain, or DomainTuple
            Domain of the new Field.
        val : float/complex/int scalar
            Fill value. Data type of the field is inferred from val.

        Returns
        -------
        Field
            The newly created Field.

        Raises
        ------
        TypeError
            If `val` is not an arithmetic scalar.
        """
        if not np.isscalar(val):
            raise TypeError("val must be a scalar")
        if not (np.isreal(val) or np.iscomplex(val)):
            raise TypeError("need arithmetic scalar")
        domain = DomainTuple.make(domain)
        return Field(domain, val)

    @staticmethod
    def from_raw(domain, arr):
        """Returns a Field constructed from `domain` and `arr`.

        Parameters
        ----------
        domain : DomainTuple, tuple of Domain, or Domain
            The domain of the new Field.
        arr : numpy.ndarray
            The data content to be used for the new Field.
            Its shape must match the shape of `domain`.
            The array is used without copying and is made read-only.
        """
        return Field(DomainTuple.make(domain), arr)

    def cast_domain(self, new_domain):
        """Returns a field with the same data, but a different domain

        Parameters
        ----------
        new_domain : Domain, tuple of Domain, or DomainTuple
            The domain for the returned field. Must be shape-compatible to
            `self`.

        Returns
        -------
        Field
            Field defined on `new_domain`, but with the same data as `self`
            (the underlying array is shared, not copied).

        Raises
        ------
        ValueError
            If the shape of `new_domain` differs from the shape of `self`.
        """
        return Field(DomainTuple.make(new_domain), self._val)

    @staticmethod
    def from_random(random_type, domain, dtype=np.float64, **kwargs):
        """Draws a random field with the given parameters.

        Parameters
        ----------
        random_type : 'pm1', 'normal', or 'uniform'
            The random distribution to use (name of a `Random` method).
        domain : Domain, tuple of Domain, or DomainTuple
            The domain of the output random Field.
        dtype : type
            The datatype of the output random Field.
        **kwargs
            Additional keyword arguments forwarded to the generator function.

        Returns
        -------
        Field
            The newly created Field.
        """
        from .random import Random
        domain = DomainTuple.make(domain)
        generator_function = getattr(Random, random_type)
        arr = generator_function(dtype=dtype, shape=domain.shape, **kwargs)
        return Field(domain, arr)

    @property
    def val(self):
        """numpy.ndarray : the array storing the field's entries.

        Notes
        -----
        The returned array is read-only.
        """
        return self._val

    def val_rw(self):
        """numpy.ndarray : a copy of the array storing the field's entries.

        The copy is writeable and independent of the field.
        """
        return self._val.copy()

    @property
    def dtype(self):
        """type : the data type of the field's entries"""
        return self._val.dtype

    @property
    def domain(self):
        """DomainTuple : the field's domain"""
        return self._domain

    @property
    def shape(self):
        """tuple of int : the concatenated shapes of all sub-domains"""
        return self._domain.shape

    @property
    def size(self):
        """int : total number of pixels in the field"""
        return self._domain.size

    @property
    def real(self):
        """Field : The real part of the field (`self` if not complex)"""
        if utilities.iscomplextype(self.dtype):
            return Field(self._domain, self._val.real)
        return self

    @property
    def imag(self):
        """Field : The imaginary part of the field

        Raises ValueError if the field is not complex-valued.
        """
        if not utilities.iscomplextype(self.dtype):
            raise ValueError(".imag called on a non-complex Field")
        return Field(self._domain, self._val.imag)

    def scalar_weight(self, spaces=None):
        """Returns the uniform volume element for a sub-domain of `self`.

        Parameters
        ----------
        spaces : int, tuple of int or None
            Indices of the sub-domains of the field's domain to be considered.
            If `None`, the entire domain is used.

        Returns
        -------
        float or None
            If the requested sub-domain has a uniform volume element, it is
            returned. Otherwise, `None` is returned.
        """
        return self._domain.scalar_weight(spaces)

    def total_volume(self, spaces=None):
        """Returns the total volume of the field's domain or of a subspace of it.

        Parameters
        ----------
        spaces : int, tuple of int or None
            Indices of the sub-domains of the field's domain to be considered.
            If `None`, the total volume of the whole domain is returned.

        Returns
        -------
        float
            the total volume of the requested (sub-)domain.
        """
        return self._domain.total_volume(spaces)

    def weight(self, power=1, spaces=None):
        """Weights the pixels of `self` with their individual pixel volumes.

        Parameters
        ----------
        power : number
            The pixel values get multiplied with their volume-factor**power.

        spaces : None, int or tuple of int
            Determines on which sub-domain the operation takes place.
            If None, the entire domain is used.

        Returns
        -------
        Field
            The weighted field.
        """
        # Work on a private, writeable copy; `self` stays untouched.
        aout = self.val_rw()

        spaces = utilities.parse_spaces(spaces, len(self._domain))

        # Scalar volume factors are accumulated and applied once at the end;
        # array-valued factors are applied per sub-domain.
        fct = 1.
        for ind in spaces:
            wgt = self._domain[ind].dvol
            if np.isscalar(wgt):
                fct *= wgt
            else:
                # Give the weight array length-1 axes everywhere except on
                # the axes of sub-domain `ind`, so it broadcasts against aout.
                # (`dtype=int`: the `np.int` alias was removed in NumPy 1.24.)
                new_shape = np.ones(len(self.shape), dtype=int)
                new_shape[self._domain.axes[ind][0]:
                          self._domain.axes[ind][-1]+1] = wgt.shape
                wgt = wgt.reshape(new_shape)
                aout *= wgt**power
        fct = fct**power
        if fct != 1.:
            aout *= fct

        return Field(self._domain, aout)

    def outer(self, x):
        """Computes the outer product of 'self' with x.

        Parameters
        ----------
        x : Field

        Returns
        -------
        Field
            Defined on the product space of self.domain and x.domain.

        Raises
        ------
        TypeError
            If `x` is not a Field.
        """
        if not isinstance(x, Field):
            raise TypeError("The multiplier must be an instance of " +
                            "the Field class")
        from .operators.outer_product_operator import OuterProduct
        return OuterProduct(self, x.domain)(x)

    def vdot(self, x=None, spaces=None):
        """Computes the dot product of 'self' with x.

        `self` is complex-conjugated, as in `numpy.vdot`.

        Parameters
        ----------
        x : Field
            x must be defined on the same domain as `self`.

        spaces : None, int or tuple of int
            The dot product is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.
            Default: None.

        Returns
        -------
        float, complex, or Field
            A scalar for full dot products, a Field for partial dot products.

        Raises
        ------
        TypeError
            If `x` is not a Field.
        ValueError
            If the domains of `self` and `x` differ.
        """
        if not isinstance(x, Field):
            raise TypeError("The dot-partner must be an instance of " +
                            "the Field class")

        if x._domain != self._domain:
            raise ValueError("Domain mismatch")

        ndom = len(self._domain)
        spaces = utilities.parse_spaces(spaces, ndom)

        if len(spaces) == ndom:
            return np.vdot(self._val, x._val)
        # If we arrive here, we have to do a partial dot product.
        # For the moment, do this the explicit, non-optimized way
        return (self.conjugate()*x).sum(spaces=spaces)

    def norm(self, ord=2):
        """Computes the Lp-norm of the (flattened) field values.

        Parameters
        ----------
        ord : int
            Order of the norm, passed to `numpy.linalg.norm`.
            Accepted values: 1, 2, ..., np.inf. Default: 2.

        Returns
        -------
        float
            The norm of order `ord` of the field values.
        """
        return np.linalg.norm(self._val.reshape(-1), ord=ord)

    def conjugate(self):
        """Returns the complex conjugate of the field.

        Returns
        -------
        Field
            The complex conjugated field (`self` if not complex).
        """
        if utilities.iscomplextype(self._val.dtype):
            return Field(self._domain, self._val.conjugate())
        return self

    # ---General unary/contraction methods---

    def __pos__(self):
        return self

    def __neg__(self):
        return Field(self._domain, -self._val)

    def __abs__(self):
        return Field(self._domain, abs(self._val))

    def _contraction_helper(self, op, spaces):
        """Applies the ndarray reduction method `op` over sub-domains.

        Parameters
        ----------
        op : str
            Name of an ndarray method accepting an `axis` argument
            (e.g. 'sum', 'prod', 'mean').
        spaces : None, int or tuple of int
            Sub-domains to reduce over. If None, reduce over everything.

        Returns
        -------
        scalar or Field
            A scalar if the reduction yields one, otherwise a Field living on
            the sub-domains that were not reduced over.
        """
        if spaces is None:
            return getattr(self._val, op)()

        spaces = utilities.parse_spaces(spaces, len(self._domain))

        axes_list = tuple(self._domain.axes[sp_index] for sp_index in spaces)

        # Flatten the per-sub-domain axis tuples into a single axis tuple.
        if len(axes_list) > 0:
            axes_list = reduce(lambda x, y: x+y, axes_list)

        # perform the contraction on the data
        data = getattr(self._val, op)(axis=axes_list)

        # check if the result is scalar or if a result_field must be constr.
        if np.isscalar(data):
            return data
        else:
            return_domain = tuple(dom
                                  for i, dom in enumerate(self._domain)
                                  if i not in spaces)

            return Field(DomainTuple.make(return_domain), data)

    def sum(self, spaces=None):
        """Sums up over the sub-domains given by `spaces`.

        Parameters
        ----------
        spaces : None, int or tuple of int
            The summation is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.

        Returns
        -------
        Field or scalar
            The result of the summation. If it is carried out over the entire
            domain, this is a scalar, otherwise a Field.
        """
        return self._contraction_helper('sum', spaces)

    def integrate(self, spaces=None):
        """Integrates over the sub-domains given by `spaces`.

        Integration is performed by summing over `self` multiplied by its
        volume factors.

        Parameters
        ----------
        spaces : None, int or tuple of int
            The summation is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.

        Returns
        -------
        Field or scalar
            The result of the integration. If it is carried out over the
            entire domain, this is a scalar, otherwise a Field.
        """
        swgt = self.scalar_weight(spaces)
        if swgt is not None:
            # Uniform volume element: sum first, multiply once.
            res = self.sum(spaces)
            res = res*swgt
            return res
        tmp = self.weight(1, spaces=spaces)
        return tmp.sum(spaces)

    def prod(self, spaces=None):
        """Computes the product over the sub-domains given by `spaces`.

        Parameters
        ----------
        spaces : None, int or tuple of int
            The operation is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.
            Default: None.

        Returns
        -------
        Field or scalar
            The result of the product. If it is carried out over the entire
            domain, this is a scalar, otherwise a Field.
        """
        return self._contraction_helper('prod', spaces)

    def all(self, spaces=None):
        """Tests whether all entries are truthy over the given sub-domains.

        Parameters
        ----------
        spaces : None, int or tuple of int
            Sub-domains to reduce over. If None, reduce over all of them.

        Returns
        -------
        Field or scalar
            A boolean scalar for a full reduction, otherwise a Field.
        """
        return self._contraction_helper('all', spaces)

    def any(self, spaces=None):
        """Tests whether any entry is truthy over the given sub-domains.

        Parameters
        ----------
        spaces : None, int or tuple of int
            Sub-domains to reduce over. If None, reduce over all of them.

        Returns
        -------
        Field or scalar
            A boolean scalar for a full reduction, otherwise a Field.
        """
        return self._contraction_helper('any', spaces)

#     def min(self, spaces=None):
#         """Determines the minimum over the sub-domains given by `spaces`.
#
#         Parameters
#         ----------
#         spaces : None, int or tuple of int (default: None)
#             The operation is only carried out over the sub-domains in this
#             tuple. If None, it is carried out over all sub-domains.
#
#         Returns
#         -------
#         Field or scalar
#             The result of the operation. If it is carried out over the entire
#             domain, this is a scalar, otherwise a Field.
#         """
#         return self._contraction_helper('min', spaces)
#
#     def max(self, spaces=None):
#         """Determines the maximum over the sub-domains given by `spaces`.
#
#         Parameters
#         ----------
#         spaces : None, int or tuple of int (default: None)
#             The operation is only carried out over the sub-domains in this
#             tuple. If None, it is carried out over all sub-domains.
#
#         Returns
#         -------
#         Field or scalar
#             The result of the operation. If it is carried out over the entire
#             domain, this is a scalar, otherwise a Field.
#         """
#         return self._contraction_helper('max', spaces)

    def mean(self, spaces=None):
        """Determines the mean over the sub-domains given by `spaces`.

        ``x.mean(spaces)`` is equivalent to
        ``x.integrate(spaces)/x.total_volume(spaces)``.

        Parameters
        ----------
        spaces : None, int or tuple of int
            The operation is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.

        Returns
        -------
        Field or scalar
            The result of the operation. If it is carried out over the entire
            domain, this is a scalar, otherwise a Field.
        """
        if self.scalar_weight(spaces) is not None:
            return self._contraction_helper('mean', spaces)
        # MR FIXME: not very efficient
        # `spaces` is required in all three calls: only the averaged
        # sub-domains are weighted, summed over and normalized by volume.
        tmp = self.weight(1, spaces)
        return tmp.sum(spaces)*(1./tmp.total_volume(spaces))

    def var(self, spaces=None):
        """Determines the variance over the sub-domains given by `spaces`.

        For complex fields, the mean of the squared modulus of the deviations
        from the mean is returned.

        Parameters
        ----------
        spaces : None, int or tuple of int
            The operation is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.
            Default: None.

        Returns
        -------
        Field or scalar
            The result of the operation. If it is carried out over the entire
            domain, this is a scalar, otherwise a Field.
        """
        if self.scalar_weight(spaces) is not None:
            return self._contraction_helper('var', spaces)
        # MR FIXME: not very efficient or accurate
        m1 = self.mean(spaces)
        if isinstance(m1, Field):
            # Partial reduction: `m1` lives only on the sub-domains that were
            # not averaged over, so `self-m1` would fail the domain check in
            # `_binary_op`. Re-insert the averaged axes with length 1 and
            # broadcast the mean back onto the full domain of `self`.
            bshape = list(self.shape)
            for sp in utilities.parse_spaces(spaces, len(self._domain)):
                for ax in self._domain.axes[sp]:
                    bshape[ax] = 1
            m1 = Field(self._domain,
                       np.broadcast_to(m1.val.reshape(bshape), self.shape))
        if utilities.iscomplextype(self.dtype):
            sq = abs(self-m1)**2
        else:
            sq = (self-m1)**2
        return sq.mean(spaces)

    def std(self, spaces=None):
        """Determines the standard deviation over the sub-domains given by
        `spaces`.

        ``x.std(spaces)`` is equivalent to ``sqrt(x.var(spaces))``.

        Parameters
        ----------
        spaces : None, int or tuple of int
            The operation is only carried out over the sub-domains in this
            tuple. If None, it is carried out over all sub-domains.
            Default: None.

        Returns
        -------
        Field or scalar
            The result of the operation. If it is carried out over the entire
            domain, this is a scalar, otherwise a Field.
        """
        from .sugar import sqrt
        if self.scalar_weight(spaces) is not None:
            return self._contraction_helper('std', spaces)
        return sqrt(self.var(spaces))

    def __repr__(self):
        return "<nifty6.Field>"

    def __str__(self):
        return "nifty6.Field instance\n- domain      = " + \
               self._domain.__str__() + \
               "\n- val         = " + repr(self._val)

    def extract(self, dom):
        """Returns `self` if `dom` equals its domain; raises ValueError
        otherwise."""
        if dom != self._domain:
            raise ValueError("domain mismatch")
        return self

    def extract_part(self, dom):
        """Returns `self` if `dom` equals its domain; raises ValueError
        otherwise."""
        if dom != self._domain:
            raise ValueError("domain mismatch")
        return self

    def unite(self, other):
        """Returns ``self + other``."""
        return self+other

    def flexible_addsub(self, other, neg):
        """Returns ``self - other`` if `neg` is true, else ``self + other``."""
        return self-other if neg else self+other

    def sigmoid(self):
        """Returns ``0.5*(1 + tanh(self))``, evaluated element-wise."""
        return 0.5*(1.+self.tanh())

    def clip(self, min=None, max=None):
        """Clips the field values to ``[min, max]`` using `numpy.clip`.

        Bounds may be scalars or Fields; for Fields their values are used.
        """
        min = min.val if isinstance(min, Field) else min
        max = max.val if isinstance(max, Field) else max
        return Field(self._domain, np.clip(self._val, min, max))

    def one_over(self):
        """Returns the element-wise reciprocal ``1/self``."""
        return 1/self

    def _binary_op(self, other, op):
        """Applies the ndarray operator method `op` to the values of `self`
        and `other`.

        `other` may be a Field on the same domain or a scalar. For any other
        operand type, NotImplemented is returned so that Python can try the
        reflected operation of `other`.
        """
        # if other is a field, make sure that the domains match
        f = getattr(self._val, op)
        if isinstance(other, Field):
            if other._domain != self._domain:
                raise ValueError("domains are incompatible.")
            return Field(self._domain, f(other._val))
        if np.isscalar(other):
            return Field(self._domain, f(other))
        return NotImplemented


def _make_binary_method(op):
    """Returns a Field method forwarding the binary operator `op` to
    `Field._binary_op`.

    A factory is used so that each method captures its own `op` (a closure
    over the loop variable would only ever see its last value).
    """
    def method(self, other):
        return self._binary_op(other, op)
    method.__name__ = op
    method.__qualname__ = "Field." + op
    return method


def _make_inplace_method(op):
    """Returns a Field method that rejects the in-place operator `op`.

    Fields are immutable, so every in-place operator raises TypeError.
    """
    def method(self, other):
        raise TypeError(
            "In-place operations are deliberately not supported")
    method.__name__ = op
    method.__qualname__ = "Field." + op
    return method


def _make_unary_method(f):
    """Returns a Field method applying the numpy function `np.<f>` to the
    field values and wrapping the result in a Field on the same domain."""
    def method(self):
        return Field(self._domain, getattr(np, f)(self.val))
    method.__name__ = f
    method.__qualname__ = "Field." + f
    method.__doc__ = "Applies numpy.{} element-wise.".format(f)
    return method


# Arithmetic and comparison operators, delegated to numpy.
for op in ("__add__", "__radd__",
           "__sub__", "__rsub__",
           "__mul__", "__rmul__",
           "__truediv__", "__rtruediv__",
           "__floordiv__", "__rfloordiv__",
           "__pow__", "__rpow__",
           "__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"):
    setattr(Field, op, _make_binary_method(op))

# In-place operators are explicitly blocked ("__idiv__" is the Python 2 name).
for op in ("__iadd__", "__isub__", "__imul__", "__idiv__",
           "__itruediv__", "__ifloordiv__", "__ipow__"):
    setattr(Field, op, _make_inplace_method(op))

# Element-wise numpy functions exposed as Field methods.
for f in ("sqrt", "exp", "log", "sin", "cos", "tan", "sinh", "cosh", "tanh",
          "absolute", "sinc", "sign", "log10", "log1p", "expm1"):
    setattr(Field, f, _make_unary_method(f))