# This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see . # # Copyright(C) 2013-2018 Max-Planck-Society # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. from __future__ import division from builtins import range import numpy as np from . import utilities from .domain_tuple import DomainTuple from functools import reduce from . import dobj __all__ = ["Field", "sqrt", "exp", "log", "conjugate"] class Field(object): """ The discrete representation of a continuous field over multiple spaces. In NIFTy, Fields are used to store data arrays and carry all the needed metainformation (i.e. the domain) for operators to be able to work on them. Parameters ---------- domain : None, DomainTuple, tuple of Domain, or Domain val : None, Field, data_object, or scalar The values the array should contain after init. A scalar input will fill the whole array with this scalar. If a data_object is provided, its dimensions must match the domain's. dtype : type A numpy.type. Most common are float and complex. copy: bool """ def __init__(self, domain=None, val=None, dtype=None, copy=False): self._domain = self._infer_domain(domain=domain, val=val) dtype = self._infer_dtype(dtype=dtype, val=val) if isinstance(val, Field): if self._domain != val._domain: raise ValueError("Domain mismatch") self._val = dobj.from_object(val.val, dtype=dtype, copy=copy) elif (np.isscalar(val)): self._val = dobj.full(self._domain.shape, dtype=dtype, fill_value=val) elif isinstance(val, dobj.data_object): if self._domain.shape == val.shape: self._val = dobj.from_object(val, dtype=dtype, copy=copy) else: raise ValueError("Shape mismatch") elif val is None: self._val = dobj.empty(self._domain.shape, dtype=dtype) else: raise TypeError("unknown source type") @staticmethod def full(domain, val, dtype=None): if not np.isscalar(val): raise TypeError("val must be a scalar") return Field(DomainTuple.make(domain), val, dtype) @staticmethod def ones(domain, dtype=None): return Field(DomainTuple.make(domain), 1., dtype) @staticmethod def zeros(domain, dtype=None): return Field(DomainTuple.make(domain), 0., dtype) @staticmethod def empty(domain, dtype=None): return Field(DomainTuple.make(domain), None, dtype) @staticmethod def full_like(field, val, dtype=None): if not isinstance(field, Field): raise TypeError("field must be of Field type") return Field.full(field._domain, val, dtype) @staticmethod def zeros_like(field, dtype=None): if not isinstance(field, Field): raise TypeError("field must be of Field type") if dtype is None: dtype = field.dtype return Field.zeros(field._domain, dtype) @staticmethod def ones_like(field, dtype=None): if not isinstance(field, Field): raise TypeError("field must be of Field type") if dtype is None: dtype = field.dtype return Field.ones(field._domain, dtype) @staticmethod def empty_like(field, dtype=None): if not isinstance(field, Field): raise TypeError("field must be of Field type") if dtype is None: dtype = field.dtype return Field.empty(field._domain, dtype) @staticmethod def _infer_domain(domain, val=None): if domain is None: if isinstance(val, Field): return val._domain if np.isscalar(val): return DomainTuple.make(()) # empty domain tuple raise TypeError("could not infer domain from value") return DomainTuple.make(domain) @staticmethod def _infer_dtype(dtype, val): if dtype is not None: return dtype if val is None: raise ValueError("could not infer dtype") if isinstance(val, Field): return val.dtype return np.result_type(val) @staticmethod def from_random(random_type, domain, dtype=np.float64, **kwargs): """ Draws a random field with the given parameters. Parameters ---------- random_type : str 'pm1', 'normal', 'uniform' are the supported arguments for this method. domain : DomainTuple The domain of the output random field dtype : type The datatype of the output random field Returns ------- Field The output object. """ domain = DomainTuple.make(domain) return Field(domain=domain, val=dobj.from_random(random_type, dtype=dtype, shape=domain.shape, **kwargs)) def fill(self, fill_value): self._val.fill(fill_value) def lock(self): dobj.lock(self._val) @property def locked(self): return dobj.locked(self._val) @property def val(self): """ Returns the data object associated with this Field. No copy is made. """ return self._val @property def dtype(self): return self._val.dtype @property def domain(self): return self._domain @property def shape(self): """ Returns the total shape of the Field's data array. Returns ------- tuple of int the dimensions of the spaces in domain. """ return self._domain.shape @property def size(self): """ Returns the total number of pixel-dimensions the field has. Effectively, all values from shape are multiplied. Returns ------- int the dimension of the Field. """ return self._domain.size @property def real(self): """ The real part of the field (data is not copied).""" if not np.issubdtype(self.dtype, np.complexfloating): return self return Field(self._domain, self.val.real) @property def imag(self): """ The imaginary part of the field (data is not copied).""" if not np.issubdtype(self.dtype, np.complexfloating): raise ValueError(".imag called on a non-complex Field") return Field(self._domain, self.val.imag) def copy(self): """ Returns a full copy of the Field. The returned object will be an identical copy of the original Field. Returns ------- Field An identical copy of 'self'. """ return Field(val=self, copy=True) def locked_copy(self): """ Returns a read-only version of the Field. If `self` is locked, returns `self`. Otherwise returns a locked copy of `self`. Returns ------- Field A read-only version of 'self'. """ if self.locked: return self res = Field(val=self, copy=True) res.lock() return res def scalar_weight(self, spaces=None): if np.isscalar(spaces): return self._domain[spaces].scalar_dvol if spaces is None: spaces = range(len(self._domain)) res = 1. for i in spaces: tmp = self._domain[i].scalar_dvol if tmp is None: return None res *= tmp return res def total_volume(self, spaces=None): if np.isscalar(spaces): return self._domain[spaces].total_volume if spaces is None: spaces = range(len(self._domain)) res = 1. for i in spaces: res *= self._domain[i].total_volume return res def weight(self, power=1, spaces=None, out=None): """ Weights the pixels of `self` with their invidual pixel-volume. Parameters ---------- power : number The pixels get weighted with the volume-factor**power. spaces : int or tuple of int Determines on which subspace the operation takes place. out : Field or None if not None, the result is returned in a new Field otherwise the contents of "out" are overwritten with the result. "out" may be identical to "self"! Returns ------- Field The weighted field. """ if out is None: out = self.copy() else: if out is not self: out.copy_content_from(self) spaces = utilities.parse_spaces(spaces, len(self._domain)) fct = 1. for ind in spaces: wgt = self._domain[ind].dvol if np.isscalar(wgt): fct *= wgt else: new_shape = np.ones(len(self.shape), dtype=np.int) new_shape[self._domain.axes[ind][0]: self._domain.axes[ind][-1]+1] = wgt.shape wgt = wgt.reshape(new_shape) if dobj.distaxis(self._val) >= 0 and ind == 0: # we need to distribute the weights along axis 0 wgt = dobj.local_data(dobj.from_global_data(wgt)) lout = dobj.local_data(out.val) lout *= wgt**power fct = fct**power if fct != 1.: out *= fct return out def vdot(self, x=None, spaces=None): """ Computes the volume-factor-aware dot product of 'self' with x. Parameters ---------- x : Field x must live on the same domain as `self`. spaces : None, int or tuple of int (default: None) The dot product is only carried out over the sub-domains in this tuple. If None, it is carried out over all sub-domains. Returns ------- float, complex, either scalar (for full dot products) or Field (for partial dot products) """ if not isinstance(x, Field): raise ValueError("The dot-partner must be an instance of " + "the NIFTy field class") if x._domain != self._domain: raise ValueError("Domain mismatch") ndom = len(self._domain) spaces = utilities.parse_spaces(spaces, ndom) if len(spaces) == ndom: return dobj.vdot(self.val, x.val) # If we arrive here, we have to do a partial dot product. # For the moment, do this the explicit, non-optimized way return (self.conjugate()*x).sum(spaces=spaces) def norm(self): """ Computes the L2-norm of the field values. Returns ------- float The L2-norm of the field values. """ return np.sqrt(np.abs(self.vdot(x=self))) def conjugate(self): """ Returns the complex conjugate of the field. Returns ------- Field The complex conjugated field. """ return Field(self._domain, self.val.conjugate()) # ---General unary/contraction methods--- def __pos__(self): return self.copy() def __neg__(self): return Field(self._domain, -self.val) def __abs__(self): return Field(self._domain, dobj.abs(self.val)) def _contraction_helper(self, op, spaces): if spaces is None: return getattr(self.val, op)() spaces = utilities.parse_spaces(spaces, len(self._domain)) axes_list = tuple(self._domain.axes[sp_index] for sp_index in spaces) if len(axes_list) > 0: axes_list = reduce(lambda x, y: x+y, axes_list) # perform the contraction on the data data = getattr(self.val, op)(axis=axes_list) # check if the result is scalar or if a result_field must be constr. if np.isscalar(data): return data else: return_domain = tuple(dom for i, dom in enumerate(self._domain) if i not in spaces) return Field(domain=return_domain, val=data, copy=False) def sum(self, spaces=None): return self._contraction_helper('sum', spaces) def integrate(self, spaces=None): swgt = self.scalar_weight(spaces) if swgt is not None: res = self.sum(spaces) res *= swgt return res tmp = self.weight(1, spaces=spaces) return tmp.sum(spaces) def prod(self, spaces=None): return self._contraction_helper('prod', spaces) def all(self, spaces=None): return self._contraction_helper('all', spaces) def any(self, spaces=None): return self._contraction_helper('any', spaces) def min(self, spaces=None): return self._contraction_helper('min', spaces) def max(self, spaces=None): return self._contraction_helper('max', spaces) def mean(self, spaces=None): if self.scalar_weight(spaces) is not None: return self._contraction_helper('mean', spaces) # MR FIXME: not very efficient tmp = self.weight(1) return tmp.sum(spaces)*(1./tmp.total_volume(spaces)) def var(self, spaces=None): if self.scalar_weight(spaces) is not None: return self._contraction_helper('var', spaces) # MR FIXME: not very efficient or accurate m1 = self.mean(spaces) if np.issubdtype(self.dtype, np.complexfloating): sq = abs(self)**2 m1 = abs(m1)**2 else: sq = self**2 m1 **= 2 return sq.mean(spaces) - m1 def std(self, spaces=None): if self.scalar_weight(spaces) is not None: return self._contraction_helper('std', spaces) return sqrt(self.var(spaces)) def copy_content_from(self, other): if not isinstance(other, Field): raise TypeError("argument must be a Field") if other._domain != self._domain: raise ValueError("domains are incompatible.") dobj.local_data(self.val)[()] = dobj.local_data(other.val)[()] def _binary_helper(self, other, op): # if other is a field, make sure that the domains match if isinstance(other, Field): if other._domain != self._domain: raise ValueError("domains are incompatible.") tval = getattr(self.val, op)(other.val) return self if tval is self.val else Field(self._domain, tval) if np.isscalar(other) or isinstance(other, dobj.data_object): tval = getattr(self.val, op)(other) return self if tval is self.val else Field(self._domain, tval) return NotImplemented def __add__(self, other): return self._binary_helper(other, op='__add__') def __radd__(self, other): return self._binary_helper(other, op='__radd__') def __iadd__(self, other): return self._binary_helper(other, op='__iadd__') def __sub__(self, other): return self._binary_helper(other, op='__sub__') def __rsub__(self, other): return self._binary_helper(other, op='__rsub__') def __isub__(self, other): return self._binary_helper(other, op='__isub__') def __mul__(self, other): return self._binary_helper(other, op='__mul__') def __rmul__(self, other): return self._binary_helper(other, op='__rmul__') def __imul__(self, other): return self._binary_helper(other, op='__imul__') def __div__(self, other): return self._binary_helper(other, op='__div__') def __truediv__(self, other): return self._binary_helper(other, op='__truediv__') def __rdiv__(self, other): return self._binary_helper(other, op='__rdiv__') def __rtruediv__(self, other): return self._binary_helper(other, op='__rtruediv__') def __idiv__(self, other): return self._binary_helper(other, op='__idiv__') def __pow__(self, other): return self._binary_helper(other, op='__pow__') def __rpow__(self, other): return self._binary_helper(other, op='__rpow__') def __ipow__(self, other): return self._binary_helper(other, op='__ipow__') def __repr__(self): return "" def __str__(self): return "nifty4.Field instance\n- domain = " + \ self._domain.__str__() + \ "\n- val = " + repr(self.val) # Arithmetic functions working on Fields def _math_helper(x, function, out): if not isinstance(x, Field): raise TypeError("This function only accepts Field objects.") if out is not None: if not isinstance(out, Field) or x._domain != out._domain: raise ValueError("Bad 'out' argument") function(x.val, out=out.val) return out else: return Field(domain=x._domain, val=function(x.val)) def sqrt(x, out=None): return _math_helper(x, dobj.sqrt, out) def exp(x, out=None): return _math_helper(x, dobj.exp, out) def log(x, out=None): return _math_helper(x, dobj.log, out) def tanh(x, out=None): return _math_helper(x, dobj.tanh, out) def conjugate(x, out=None): return _math_helper(x, dobj.conjugate, out)