Commit 88d89ed6 authored by theos's avatar theos
Browse files

Finalized LinearOperator base class. Added a first version of a square operator.

parent 0448204d
......@@ -236,6 +236,14 @@ def cast_axis_to_tuple(axis, length):
# shift negative indices to positive ones
axis = tuple(item if (item >= 0) else (item + length) for item in axis)
# remove duplicate entries
axis = tuple(set(axis))
# assert that all entries are elements in [0, length]
for elem in axis:
assert(0 <= elem < length)
return axis
......
......@@ -20,6 +20,13 @@
## along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import division
from linear_operator import LinearOperator,\
LinearOperatorParadict
from square_operator import SquareOperator,\
SquareOperatorParadict
from nifty_operators import operator,\
diagonal_operator,\
power_operator,\
......
# -*- coding: utf-8 -*-
from linear_operator import LinearOperator
from linear_operator_paradict import LinearOperatorParadict
# -*- coding: utf-8 -*-
from nifty.config import about
from nifty.field import Field
from nifty.spaces import Space
from nifty.field_types import FieldType
import nifty.nifty_utilities as utilities
from linear_operator_paradict import LinearOperatorParadict
class LinearOperator(object):
def __init__(self, domain=None, target=None,
field_type=None, field_type_target=None,
implemented=False, symmetric=False, unitary=False):
self.paradict = LinearOperatorParadict()
self._implemented = bool(implemented)
self.domain = self._parse_domain(domain)
self.target = self._parse_domain(target)
self.field_type = self._parse_field_type(field_type)
self.field_type_target = self._parse_field_type(field_type_target)
def _parse_domain(self, domain):
if domain is None:
domain = ()
elif not isinstance(domain, tuple):
domain = (domain,)
for d in domain:
if not isinstance(d, Space):
raise TypeError(about._errors.cstring(
"ERROR: Given object contains something that is not a "
"nifty.space."))
return domain
def _parse_field_type(self, field_type):
if field_type is None:
field_type = ()
elif not isinstance(field_type, tuple):
field_type = (field_type,)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(about._errors.cstring(
"ERROR: Given object is not a nifty.FieldType."))
return field_type
@property
def implemented(self):
return self._implemented
def __call__(self, *args, **kwargs):
return self.times(*args, **kwargs)
def times(self, x, spaces=None, types=None):
spaces, types = self._check_input_compatibility(x, spaces, types)
if not self.implemented:
x = x.weight(spaces=spaces)
y = self._times(x, spaces, types)
return y
def inverse_times(self, x, spaces=None, types=None):
spaces, types = self._check_input_compatibility(x, spaces, types)
y = self._inverse_times(x, spaces, types)
if not self.implemented:
y = y.weight(power=-1, spaces=spaces)
return y
def adjoint_times(self, x, spaces=None, types=None):
spaces, types = self._check_input_compatibility(x, spaces, types)
if not self.implemented:
x = x.weight(spaces=spaces)
y = self._adjoint_times(x, spaces, types)
return y
def adjoint_inverse_times(self, x, spaces=None, types=None):
spaces, types = self._check_input_compatibility(x, spaces, types)
y = self._adjoint_inverse_times(x, spaces, types)
if not self.implemented:
y = y.weight(power=-1, spaces=spaces)
return y
def inverse_adjoint_times(self, x, spaces=None, types=None):
spaces, types = self._check_input_compatibility(x, spaces, types)
y = self._inverse_adjoint_times(x, spaces, types)
if not self.implemented:
y = y.weight(power=-1, spaces=spaces)
return y
def _times(self, x, spaces, types):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'times'."))
def _adjoint_times(self, x, spaces, types):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'adjoint_times'."))
def _inverse_times(self, x, spaces, types):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'inverse_times'."))
def _adjoint_inverse_times(self, x, spaces, types):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'adjoint_inverse_times'."))
def _inverse_adjoint_times(self, x, spaces, types):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'inverse_adjoint_times'."))
def _check_input_compatibility(self, x, spaces, types):
if not isinstance(x, Field):
raise ValueError(about._errors.cstring(
"ERROR: supplied object is not a `nifty.Field`."))
# sanitize the `spaces` and `types` input
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
types = utilities.cast_axis_to_tuple(types, len(x.field_type))
# if the operator's domain is set to something, there are two valid
# cases:
# 1. Case:
# The user specifies with `spaces` that the operators domain should
# be applied to a certain domain in the domain-tuple of x. This is
# only valid if len(self.domain)==1.
# 2. Case:
# The domains of self and x match completely.
if spaces is None:
if self.domain != () and self.domain != x.domain:
raise ValueError(about._errors.cstring(
"ERROR: The operator's and and field's domains don't "
"match."))
else:
if len(self.domain) > 1:
raise ValueError(about._errors.cstring(
"ERROR: Specifying `spaces` for operators with multiple "
"domain spaces is not valid."))
elif len(spaces) != len(self.domain):
raise ValueError(about._errors.cstring(
"ERROR: Length of `spaces` does not match the number of "
"spaces in the operator's domain."))
elif len(spaces) == 1:
if x.domain[spaces[0]] != self.domain[0]:
raise ValueError(about._errors.cstring(
"ERROR: The operator's and and field's domains don't "
"match."))
if types is None:
if self.field_type != () and self.field_type != x.field_type:
raise ValueError(about._errors.cstring(
"ERROR: The operator's and and field's field_types don't "
"match."))
else:
if len(self.field_type) > 1:
raise ValueError(about._errors.cstring(
"ERROR: Specifying `types` for operators with multiple "
"field-types is not valid."))
elif len(types) != len(self.field_type):
raise ValueError(about._errors.cstring(
"ERROR: Length of `types` does not match the number of "
"the operator's field-types."))
elif len(types) == 1:
if x.field_type[types[0]] != self.field_type[0]:
raise ValueError(about._errors.cstring(
"ERROR: The operator's and and field's field_type "
"don't match."))
return (spaces, types)
def __repr__(self):
return str(self.__class__)
......@@ -3,5 +3,5 @@
from nifty.paradict import Paradict
class OperatorParadict(Paradict):
class LinearOperatorParadict(Paradict):
pass
# -*- coding: utf-8 -*-
from nifty.config import about
from operator_paradict import OperatorParadict
class LinearOperator(object):
def __init__(self, domain=None, target=None,
field_type=None, field_type_target=None,
implemented=False, symmetric=False, unitary=False,
**kwargs):
self.paradict = OperatorParadict(**kwargs)
self.implemented = implemented
self.symmetric = symmetric
self.unitary = unitary
@property
def implemented(self):
return self._implemented
@implemented.setter
def implemented(self, b):
self._implemented = bool(b)
@property
def symmetric(self):
return self._symmetric
@symmetric.setter
def symmetric(self, b):
self._symmetric = bool(b)
@property
def unitary(self):
return self._unitary
@unitary.setter
def unitary(self, b):
self._unitary = bool(b)
def times(self, x, spaces=None, types=None):
raise NotImplementedError
def adjoint_times(self, x, spaces=None, types=None):
raise NotImplementedError
def inverse_times(self, x, spaces=None, types=None):
raise NotImplementedError
def adjoint_inverse_times(self, x, spaces=None, types=None):
raise NotImplementedError
def inverse_adjoint_times(self, x, spaces=None, types=None):
raise NotImplementedError
def _times(self, x, **kwargs):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'times'."))
def _adjoint_times(self, x, **kwargs):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'adjoint_times'."))
def _inverse_times(self, x, **kwargs):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'inverse_times'."))
def _adjoint_inverse_times(self, x, **kwargs):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'adjoint_inverse_times'."))
def _inverse_adjoint_times(self, x, **kwargs):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'inverse_adjoint_times'."))
def _check_input_compatibility(self, x, spaces, types):
# assert: x is a field
# if spaces is None -> assert f.domain == self.domain
# -> same for field_type
# else: check if self.domain/self.field_type == one entry.
#
# -*- coding: utf-8 -*-
from square_operator import SquareOperator
from square_operator_paradict import SquareOperatorParadict
# -*- coding: utf-8 -*-
from nifty.config import about
from nifty.operators.linear_operator import LinearOperator
from square_operator_paradict import SquareOperatorParadict
class SquareOperator(LinearOperator):
def __init__(self, domain=None, target=None,
field_type=None, field_type_target=None,
implemented=False, symmetric=False, unitary=False):
if target is not None:
about.warnings.cprint(
"WARNING: Discarding given target for SquareOperator.")
target = domain
if field_type_target is not None:
about.warnings.cprint(
"WARNING: Discarding given field_type_target for "
"SquareOperator.")
field_type_target = field_type
LinearOperator.__init__(self,
domain=domain,
target=target,
field_type=field_type,
field_type_target=field_type_target,
implemented=implemented)
self.paradict = SquareOperatorParadict(symmetric=symmetric,
unitary=unitary)
def inverse_times(self, x, spaces=None, types=None):
if self.paradict['symmetric'] and self.paradict['unitary']:
return self.times(x, spaces, types)
else:
return LinearOperator.inverse_times(self,
x=x,
spaces=spaces,
types=types)
def adjoint_times(self, x, spaces=None, types=None):
if self.paradict['symmetric']:
return self.times(x, spaces, types)
elif self.paradict['unitary']:
return self.inverse_times(x, spaces, types)
else:
return LinearOperator.adjoint_times(self,
x=x,
spaces=spaces,
types=types)
def adjoint_inverse_times(self, x, spaces=None, types=None):
if self.paradict['symmetric']:
return self.inverse_times(x, spaces, types)
elif self.paradict['unitary']:
return self.times(x, spaces, types)
else:
return LinearOperator.adjoint_inverse_times(self,
x=x,
spaces=spaces,
types=types)
def inverse_adjoint_times(self, x, spaces=None, types=None):
if self.paradict['symmetric']:
return self.inverse_times(x, spaces, types)
elif self.paradict['unitary']:
return self.times(x, spaces, types)
else:
return LinearOperator.inverse_adjoint_times(self,
x=x,
spaces=spaces,
types=types)
def trace(self):
pass
def inverse_trace(self):
pass
def diagonal(self):
pass
def inverse_diagonal(self):
pass
def determinant(self):
pass
def inverse_determinant(self):
pass
def log_determinant(self):
pass
def trace_log(self):
pass
# -*- coding: utf-8 -*-
from nifty.config import about
from nifty.operators.linear_operator import LinearOperatorParadict
class SquareOperatorParadict(LinearOperatorParadict):
def __init__(self, symmetric, unitary):
LinearOperatorParadict.__init__(self,
symmetric=symmetric,
unitary=unitary)
def __setitem__(self, key, arg):
if key not in ['symmetric', 'unitary']:
raise ValueError(about._errors.cstring(
"ERROR: Unsupported SquareOperator parameter: " + key))
if key == 'symmetric':
temp = bool(arg)
elif key == 'unitary':
temp = bool(arg)
self.parameters.__setitem__(key, temp)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment