Commit cb713171 authored by theos's avatar theos
Browse files

Fixed the return type of the FFTOperator.

Fixed violation of LSP for implemented keyword/property.
parent 28b305d3
......@@ -49,14 +49,17 @@ class Field(object):
self.set_val(new_val=val, copy=copy)
def _parse_domain(self, domain, val):
def _parse_domain(self, domain, val=None):
if domain is None:
if isinstance(val, Field):
domain = val.domain
else:
domain = ()
elif not isinstance(domain, tuple):
elif isinstance(domain, Space):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
raise TypeError(about._errors.cstring(
......@@ -64,14 +67,16 @@ class Field(object):
"nifty.space."))
return domain
def _parse_field_type(self, field_type, val):
def _parse_field_type(self, field_type, val=None):
if field_type is None:
if isinstance(val, Field):
field_type = val.field_type
else:
field_type = ()
elif not isinstance(field_type, tuple):
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(about._errors.cstring(
......
......@@ -16,11 +16,13 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), implemented=False,
diagonal=None, bare=False, datamodel=None, copy=True):
diagonal=None, bare=False, copy=True, datamodel=None):
super(DiagonalOperator, self).__init__(domain=domain,
field_type=field_type,
implemented=implemented)
self._implemented = bool(implemented)
if datamodel is None:
if isinstance(diagonal, distributed_data_object):
datamodel = diagonal.distribution_strategy
......@@ -80,6 +82,10 @@ class DiagonalOperator(EndomorphicOperator):
# ---Mandatory properties and methods---
@property
def implemented(self):
return self._implemented
@property
def symmetric(self):
return self._symmetric
......@@ -116,8 +122,16 @@ class DiagonalOperator(EndomorphicOperator):
datamodel=self.datamodel,
copy=copy)
# weight if the given values were `bare`
f.weight(inplace=True)
# weight if the given values were `bare` and `implemented` is True
# do inverse weightening if the other way around
if bare and self.implemented:
# If `copy` is True, we won't change external data by weightening
# Otherwise, inplace weightening would change the external field
f.weight(inplace=copy)
elif not bare and not self.implemented:
# If `copy` is True, we won't change external data by weightening
# Otherwise, inplace weightening would change the external field
f.weight(inplace=copy, power=-1)
# check if the operator is symmetric:
self._symmetric = (f.val.imag == 0).all()
......@@ -127,4 +141,3 @@ class DiagonalOperator(EndomorphicOperator):
# store the diagonal-field
self._diagonal = f
from transformations import *
from fft_operator import FFTOperator
\ No newline at end of file
from fft_operator import FFTOperator
......@@ -8,41 +8,26 @@ class FFTOperator(LinearOperator):
# ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), target=(),
field_type_target=(), implemented=True):
def __init__(self, domain=(), field_type=(), target=None):
super(FFTOperator, self).__init__(domain=domain,
field_type=field_type,
implemented=implemented)
field_type=field_type)
if self.domain == ():
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator needs a single space as '
'input domain.'
))
else:
if len(self.domain) > 1:
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator accepts only a single '
'space as input domain.'
))
if len(self.domain) != 1:
raise ValueError(about._errors.cstring(
'ERROR: TransformationOperator accepts only exactly one '
'space as input domain.'))
if self.field_type != ():
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator field-type has to be an '
raise ValueError(about._errors.cstring(
'ERROR: TransformationOperator field-type must be an '
'empty tuple.'
))
# currently not sanitizing the target
self._target = self._parse_domain(
utilities.get_default_codomain(self.domain[0])
)
self._field_type_target = self._parse_field_type(field_type_target)
if target is None:
target = utilities.get_default_codomain(self.domain[0])
if self.field_type_target != ():
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator target field-type has to be an '
'empty tuple.'
))
self._target = self._parse_domain(
utilities.get_default_codomain(self.domain[0]))
self._forward_transformation = TransformationFactory.create(
self.domain[0], self.target[0]
......@@ -52,24 +37,37 @@ class FFTOperator(LinearOperator):
self.target[0], self.domain[0]
)
def adjoint_times(self, x, spaces=None, types=None):
return self.inverse_times(x, spaces, types)
def _times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
def adjoint_inverse_times(self, x, spaces=None, types=None):
return self.times(x, spaces, types)
new_val = self._forward_transformation.transform(x.val, axes=spaces)
def inverse_adjoint_times(self, x, spaces=None, types=None):
return self.times(x, spaces, types)
if spaces is None:
result_domain = self.target
else:
result_domain = list(x.domain)
result_domain[spaces[0]] = self.target[0]
def _times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
result_field = x.copy_empty(domain=result_domain)
result_field.set_val(new_val=new_val)
return self._forward_transformation.transform(x.val, axes=spaces)
return result_field
def _inverse_times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
return self._inverse_transformation.transform(x.val, axes=spaces)
new_val = self._inverse_transformation.transform(x.val, axes=spaces)
if spaces is None:
result_domain = self.domain
else:
result_domain = list(x.domain)
result_domain[spaces[0]] = self.domain[0]
result_field = x.copy_empty(domain=result_domain)
result_field.set_val(new_val=new_val)
return result_field
# ---Mandatory properties and methods---
......@@ -79,5 +77,12 @@ class FFTOperator(LinearOperator):
@property
def field_type_target(self):
return self._field_type_target
return self.field_type
@property
def implemented(self):
return True
@property
def unitary(self):
return True
......@@ -12,10 +12,9 @@ import nifty.nifty_utilities as utilities
class LinearOperator(object):
__metaclass__ = abc.ABCMeta
def __init__(self, domain=(), field_type=(), implemented=False):
def __init__(self, domain=(), field_type=()):
self._domain = self._parse_domain(domain)
self._field_type = self._parse_field_type(field_type)
self._implemented = bool(implemented)
@property
def domain(self):
......@@ -36,8 +35,11 @@ class LinearOperator(object):
def _parse_domain(self, domain):
if domain is None:
domain = ()
elif not isinstance(domain, tuple):
elif isinstance(domain, Space):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
raise TypeError(about._errors.cstring(
......@@ -48,17 +50,20 @@ class LinearOperator(object):
def _parse_field_type(self, field_type):
if field_type is None:
field_type = ()
elif not isinstance(field_type, tuple):
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(about._errors.cstring(
"ERROR: Given object is not a nifty.FieldType."))
return field_type
@property
@abc.abstractproperty
def implemented(self):
return self._implemented
raise NotImplementedError
@abc.abstractproperty
def unitary(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment