diff --git a/nifty/nifty_utilities.py b/nifty/nifty_utilities.py index 3ef644e0a14430b14a4ad14b356f8a7aff2bec2e..2bafae96500ea71a0e9183976ebd682088e80784 100644 --- a/nifty/nifty_utilities.py +++ b/nifty/nifty_utilities.py @@ -236,6 +236,14 @@ def cast_axis_to_tuple(axis, length): # shift negative indices to positive ones axis = tuple(item if (item >= 0) else (item + length) for item in axis) + + # remove duplicate entries + axis = tuple(set(axis)) + + # assert that all entries are elements in [0, length] + for elem in axis: + assert(0 <= elem < length) + return axis diff --git a/nifty/operators/__init__.py b/nifty/operators/__init__.py index c9b34eb874fb481bd44afd855b660c805e2f3912..0cc18d79353fa58764adf2c4fc965a1c1483e6d6 100644 --- a/nifty/operators/__init__.py +++ b/nifty/operators/__init__.py @@ -20,6 +20,13 @@ ## along with this program. If not, see <http://www.gnu.org/licenses/>. from __future__ import division + +from linear_operator import LinearOperator,\ + LinearOperatorParadict + +from square_operator import SquareOperator,\ + SquareOperatorParadict + from nifty_operators import operator,\ diagonal_operator,\ power_operator,\ diff --git a/nifty/operators/linear_operator/__init__.py b/nifty/operators/linear_operator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9bebe593f080465cce077c56bebf1c64d2cffb2 --- /dev/null +++ b/nifty/operators/linear_operator/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- + +from linear_operator import LinearOperator +from linear_operator_paradict import LinearOperatorParadict diff --git a/nifty/operators/linear_operator/linear_operator.py b/nifty/operators/linear_operator/linear_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..e1ad40de06b544a7717fb5441a825a54eab51e95 --- /dev/null +++ b/nifty/operators/linear_operator/linear_operator.py @@ -0,0 +1,178 @@ +# -*- coding: utf-8 -*- + +from nifty.config import about +from nifty.field import Field +from nifty.spaces import Space +from nifty.field_types import FieldType +import nifty.nifty_utilities as utilities + +from linear_operator_paradict import LinearOperatorParadict + + +class LinearOperator(object): + + def __init__(self, domain=None, target=None, + field_type=None, field_type_target=None, + implemented=False, symmetric=False, unitary=False): + self.paradict = LinearOperatorParadict() + + self._implemented = bool(implemented) + + self.domain = self._parse_domain(domain) + self.target = self._parse_domain(target) + + self.field_type = self._parse_field_type(field_type) + self.field_type_target = self._parse_field_type(field_type_target) + + def _parse_domain(self, domain): + if domain is None: + domain = () + elif not isinstance(domain, tuple): + domain = (domain,) + for d in domain: + if not isinstance(d, Space): + raise TypeError(about._errors.cstring( + "ERROR: Given object contains something that is not a " + "nifty.space.")) + return domain + + def _parse_field_type(self, field_type): + if field_type is None: + field_type = () + elif not isinstance(field_type, tuple): + field_type = (field_type,) + for ft in field_type: + if not isinstance(ft, FieldType): + raise TypeError(about._errors.cstring( + "ERROR: Given object is not a nifty.FieldType.")) + return field_type + + @property + def implemented(self): + return self._implemented + + def __call__(self, *args, **kwargs): + return self.times(*args, **kwargs) + + def times(self, x, spaces=None, types=None): + spaces, types = self._check_input_compatibility(x, spaces, types) + + if not self.implemented: + x = x.weight(spaces=spaces) + + y = self._times(x, spaces, types) + return y + + def inverse_times(self, x, spaces=None, types=None): + spaces, types = self._check_input_compatibility(x, spaces, types) + + y = self._inverse_times(x, spaces, types) + if not self.implemented: + y = y.weight(power=-1, spaces=spaces) + return y + + def adjoint_times(self, x, spaces=None, types=None): + spaces, types = self._check_input_compatibility(x, spaces, types) + + if not self.implemented: + x = x.weight(spaces=spaces) + y = self._adjoint_times(x, spaces, types) + return y + + def adjoint_inverse_times(self, x, spaces=None, types=None): + spaces, types = self._check_input_compatibility(x, spaces, types) + + y = self._adjoint_inverse_times(x, spaces, types) + if not self.implemented: + y = y.weight(power=-1, spaces=spaces) + return y + + def inverse_adjoint_times(self, x, spaces=None, types=None): + spaces, types = self._check_input_compatibility(x, spaces, types) + + y = self._inverse_adjoint_times(x, spaces, types) + if not self.implemented: + y = y.weight(power=-1, spaces=spaces) + return y + + def _times(self, x, spaces, types): + raise NotImplementedError(about._errors.cstring( + "ERROR: no generic instance method 'times'.")) + + def _adjoint_times(self, x, spaces, types): + raise NotImplementedError(about._errors.cstring( + "ERROR: no generic instance method 'adjoint_times'.")) + + def _inverse_times(self, x, spaces, types): + raise NotImplementedError(about._errors.cstring( + "ERROR: no generic instance method 'inverse_times'.")) + + def _adjoint_inverse_times(self, x, spaces, types): + raise NotImplementedError(about._errors.cstring( + "ERROR: no generic instance method 'adjoint_inverse_times'.")) + + def _inverse_adjoint_times(self, x, spaces, types): + raise NotImplementedError(about._errors.cstring( + "ERROR: no generic instance method 'inverse_adjoint_times'.")) + + def _check_input_compatibility(self, x, spaces, types): + if not isinstance(x, Field): + raise ValueError(about._errors.cstring( + "ERROR: supplied object is not a `nifty.Field`.")) + + # sanitize the `spaces` and `types` input + spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) + types = utilities.cast_axis_to_tuple(types, len(x.field_type)) + + # if the operator's domain is set to something, there are two valid + # cases: + # 1. Case: + # The user specifies with `spaces` that the operators domain should + # be applied to a certain domain in the domain-tuple of x. This is + # only valid if len(self.domain)==1. + # 2. Case: + # The domains of self and x match completely. + + if spaces is None: + if self.domain != () and self.domain != x.domain: + raise ValueError(about._errors.cstring( + "ERROR: The operator's and and field's domains don't " + "match.")) + else: + if len(self.domain) > 1: + raise ValueError(about._errors.cstring( + "ERROR: Specifying `spaces` for operators with multiple " + "domain spaces is not valid.")) + elif len(spaces) != len(self.domain): + raise ValueError(about._errors.cstring( + "ERROR: Length of `spaces` does not match the number of " + "spaces in the operator's domain.")) + elif len(spaces) == 1: + if x.domain[spaces[0]] != self.domain[0]: + raise ValueError(about._errors.cstring( + "ERROR: The operator's and and field's domains don't " + "match.")) + + if types is None: + if self.field_type != () and self.field_type != x.field_type: + raise ValueError(about._errors.cstring( + "ERROR: The operator's and and field's field_types don't " + "match.")) + else: + if len(self.field_type) > 1: + raise ValueError(about._errors.cstring( + "ERROR: Specifying `types` for operators with multiple " + "field-types is not valid.")) + elif len(types) != len(self.field_type): + raise ValueError(about._errors.cstring( + "ERROR: Length of `types` does not match the number of " + "the operator's field-types.")) + elif len(types) == 1: + if x.field_type[types[0]] != self.field_type[0]: + raise ValueError(about._errors.cstring( + "ERROR: The operator's and and field's field_type " + "don't match.")) + return (spaces, types) + + def __repr__(self): + return str(self.__class__) diff --git a/nifty/operators/operator/operator_paradict.py b/nifty/operators/linear_operator/linear_operator_paradict.py similarity index 64% rename from nifty/operators/operator/operator_paradict.py rename to nifty/operators/linear_operator/linear_operator_paradict.py index 4e52353bf8571d764a7e15e83ccfa82a9b5556ef..2c5e1177c30b076cb59ee6fd95a2c4716b4c3da5 100644 --- a/nifty/operators/operator/operator_paradict.py +++ b/nifty/operators/linear_operator/linear_operator_paradict.py @@ -3,5 +3,5 @@ from nifty.paradict import Paradict -class OperatorParadict(Paradict): +class LinearOperatorParadict(Paradict): pass diff --git a/nifty/operators/operator/operator.py b/nifty/operators/operator/operator.py deleted file mode 100644 index b2c8113588cd1fb326bd658ef690d23229ae992a..0000000000000000000000000000000000000000 --- a/nifty/operators/operator/operator.py +++ /dev/null @@ -1,85 +0,0 @@ -# -*- coding: utf-8 -*- - -from nifty.config import about -from operator_paradict import OperatorParadict - - -class LinearOperator(object): - - def __init__(self, domain=None, target=None, - field_type=None, field_type_target=None, - implemented=False, symmetric=False, unitary=False, - **kwargs): - self.paradict = OperatorParadict(**kwargs) - - self.implemented = implemented - self.symmetric = symmetric - self.unitary = unitary - - @property - def implemented(self): - return self._implemented - - @implemented.setter - def implemented(self, b): - self._implemented = bool(b) - - @property - def symmetric(self): - return self._symmetric - - @symmetric.setter - def symmetric(self, b): - self._symmetric = bool(b) - - @property - def unitary(self): - return self._unitary - - @unitary.setter - def unitary(self, b): - self._unitary = bool(b) - - def times(self, x, spaces=None, types=None): - raise NotImplementedError - - def adjoint_times(self, x, spaces=None, types=None): - raise NotImplementedError - - def inverse_times(self, x, spaces=None, types=None): - raise NotImplementedError - - def adjoint_inverse_times(self, x, spaces=None, types=None): - raise NotImplementedError - - def inverse_adjoint_times(self, x, spaces=None, types=None): - raise NotImplementedError - - def _times(self, x, **kwargs): - raise NotImplementedError(about._errors.cstring( - "ERROR: no generic instance method 'times'.")) - - def _adjoint_times(self, x, **kwargs): - raise NotImplementedError(about._errors.cstring( - "ERROR: no generic instance method 'adjoint_times'.")) - - def _inverse_times(self, x, **kwargs): - raise NotImplementedError(about._errors.cstring( - "ERROR: no generic instance method 'inverse_times'.")) - - def _adjoint_inverse_times(self, x, **kwargs): - raise NotImplementedError(about._errors.cstring( - "ERROR: no generic instance method 'adjoint_inverse_times'.")) - - def _inverse_adjoint_times(self, x, **kwargs): - raise NotImplementedError(about._errors.cstring( - "ERROR: no generic instance method 'inverse_adjoint_times'.")) - - def _check_input_compatibility(self, x, spaces, types): - # assert: x is a field - # if spaces is None -> assert f.domain == self.domain - # -> same for field_type - # else: check if self.domain/self.field_type == one entry. - # - - diff --git a/nifty/operators/square_operator/__init__.py b/nifty/operators/square_operator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4fb72eb6872855fe158957989a93ca65e969c8b9 --- /dev/null +++ b/nifty/operators/square_operator/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- + +from square_operator import SquareOperator +from square_operator_paradict import SquareOperatorParadict diff --git a/nifty/operators/square_operator/square_operator.py b/nifty/operators/square_operator/square_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..5920857e7d27fe334b3fe2fc0cbe79f16f59e22f --- /dev/null +++ b/nifty/operators/square_operator/square_operator.py @@ -0,0 +1,99 @@ +# -*- coding: utf-8 -*- + +from nifty.config import about +from nifty.operators.linear_operator import LinearOperator +from square_operator_paradict import SquareOperatorParadict + + +class SquareOperator(LinearOperator): + + def __init__(self, domain=None, target=None, + field_type=None, field_type_target=None, + implemented=False, symmetric=False, unitary=False): + + if target is not None: + about.warnings.cprint( + "WARNING: Discarding given target for SquareOperator.") + target = domain + + if field_type_target is not None: + about.warnings.cprint( + "WARNING: Discarding given field_type_target for " + "SquareOperator.") + field_type_target = field_type + + LinearOperator.__init__(self, + domain=domain, + target=target, + field_type=field_type, + field_type_target=field_type_target, + implemented=implemented) + + self.paradict = SquareOperatorParadict(symmetric=symmetric, + unitary=unitary) + + def inverse_times(self, x, spaces=None, types=None): + if self.paradict['symmetric'] and self.paradict['unitary']: + return self.times(x, spaces, types) + else: + return LinearOperator.inverse_times(self, + x=x, + spaces=spaces, + types=types) + + def adjoint_times(self, x, spaces=None, types=None): + if self.paradict['symmetric']: + return self.times(x, spaces, types) + elif self.paradict['unitary']: + return self.inverse_times(x, spaces, types) + else: + return LinearOperator.adjoint_times(self, + x=x, + spaces=spaces, + types=types) + + def adjoint_inverse_times(self, x, spaces=None, types=None): + if self.paradict['symmetric']: + return self.inverse_times(x, spaces, types) + elif self.paradict['unitary']: + return self.times(x, spaces, types) + else: + return LinearOperator.adjoint_inverse_times(self, + x=x, + spaces=spaces, + types=types) + + def inverse_adjoint_times(self, x, spaces=None, types=None): + if self.paradict['symmetric']: + return self.inverse_times(x, spaces, types) + elif self.paradict['unitary']: + return self.times(x, spaces, types) + else: + return LinearOperator.inverse_adjoint_times(self, + x=x, + spaces=spaces, + types=types) + + def trace(self): + pass + + def inverse_trace(self): + pass + + def diagonal(self): + pass + + def inverse_diagonal(self): + pass + + def determinant(self): + pass + + def inverse_determinant(self): + pass + + def log_determinant(self): + pass + + def trace_log(self): + pass diff --git a/nifty/operators/square_operator/square_operator_paradict.py b/nifty/operators/square_operator/square_operator_paradict.py new file mode 100644 index 0000000000000000000000000000000000000000..be2c24c066fe035b681edd0df1411902d5b97bda --- /dev/null +++ b/nifty/operators/square_operator/square_operator_paradict.py @@ -0,0 +1,22 @@ +# -*- coding: utf-8 -*- + +from nifty.config import about +from nifty.operators.linear_operator import LinearOperatorParadict + + +class SquareOperatorParadict(LinearOperatorParadict): + def __init__(self, symmetric, unitary): + LinearOperatorParadict.__init__(self, + symmetric=symmetric, + unitary=unitary) + + def __setitem__(self, key, arg): + if key not in ['symmetric', 'unitary']: + raise ValueError(about._errors.cstring( + "ERROR: Unsupported SquareOperator parameter: " + key)) + if key == 'symmetric': + temp = bool(arg) + elif key == 'unitary': + temp = bool(arg) + + self.parameters.__setitem__(key, temp)