Commit cb713171 authored by theos's avatar theos
Browse files

Fixed the return type of the FFTOperator.

Fixed violation of LSP for implemented keyword/property.
parent 28b305d3
...@@ -49,14 +49,17 @@ class Field(object): ...@@ -49,14 +49,17 @@ class Field(object):
self.set_val(new_val=val, copy=copy) self.set_val(new_val=val, copy=copy)
def _parse_domain(self, domain, val): def _parse_domain(self, domain, val=None):
if domain is None: if domain is None:
if isinstance(val, Field): if isinstance(val, Field):
domain = val.domain domain = val.domain
else: else:
domain = () domain = ()
elif not isinstance(domain, tuple): elif isinstance(domain, Space):
domain = (domain,) domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain: for d in domain:
if not isinstance(d, Space): if not isinstance(d, Space):
raise TypeError(about._errors.cstring( raise TypeError(about._errors.cstring(
...@@ -64,14 +67,16 @@ class Field(object): ...@@ -64,14 +67,16 @@ class Field(object):
"nifty.space.")) "nifty.space."))
return domain return domain
def _parse_field_type(self, field_type, val): def _parse_field_type(self, field_type, val=None):
if field_type is None: if field_type is None:
if isinstance(val, Field): if isinstance(val, Field):
field_type = val.field_type field_type = val.field_type
else: else:
field_type = () field_type = ()
elif not isinstance(field_type, tuple): elif isinstance(field_type, FieldType):
field_type = (field_type,) field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type: for ft in field_type:
if not isinstance(ft, FieldType): if not isinstance(ft, FieldType):
raise TypeError(about._errors.cstring( raise TypeError(about._errors.cstring(
......
...@@ -16,11 +16,13 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -16,11 +16,13 @@ class DiagonalOperator(EndomorphicOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), implemented=False, def __init__(self, domain=(), field_type=(), implemented=False,
diagonal=None, bare=False, datamodel=None, copy=True): diagonal=None, bare=False, copy=True, datamodel=None):
super(DiagonalOperator, self).__init__(domain=domain, super(DiagonalOperator, self).__init__(domain=domain,
field_type=field_type, field_type=field_type,
implemented=implemented) implemented=implemented)
self._implemented = bool(implemented)
if datamodel is None: if datamodel is None:
if isinstance(diagonal, distributed_data_object): if isinstance(diagonal, distributed_data_object):
datamodel = diagonal.distribution_strategy datamodel = diagonal.distribution_strategy
...@@ -80,6 +82,10 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -80,6 +82,10 @@ class DiagonalOperator(EndomorphicOperator):
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
@property
def implemented(self):
return self._implemented
@property @property
def symmetric(self): def symmetric(self):
return self._symmetric return self._symmetric
...@@ -116,8 +122,16 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -116,8 +122,16 @@ class DiagonalOperator(EndomorphicOperator):
datamodel=self.datamodel, datamodel=self.datamodel,
copy=copy) copy=copy)
# weight if the given values were `bare` # weight if the given values were `bare` and `implemented` is True
f.weight(inplace=True) # do inverse weightening if the other way around
if bare and self.implemented:
# If `copy` is True, we won't change external data by weightening
# Otherwise, inplace weightening would change the external field
f.weight(inplace=copy)
elif not bare and not self.implemented:
# If `copy` is True, we won't change external data by weightening
# Otherwise, inplace weightening would change the external field
f.weight(inplace=copy, power=-1)
# check if the operator is symmetric: # check if the operator is symmetric:
self._symmetric = (f.val.imag == 0).all() self._symmetric = (f.val.imag == 0).all()
...@@ -127,4 +141,3 @@ class DiagonalOperator(EndomorphicOperator): ...@@ -127,4 +141,3 @@ class DiagonalOperator(EndomorphicOperator):
# store the diagonal-field # store the diagonal-field
self._diagonal = f self._diagonal = f
from transformations import * from transformations import *
from fft_operator import FFTOperator from fft_operator import FFTOperator
\ No newline at end of file
...@@ -8,41 +8,26 @@ class FFTOperator(LinearOperator): ...@@ -8,41 +8,26 @@ class FFTOperator(LinearOperator):
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
def __init__(self, domain=(), field_type=(), target=(), def __init__(self, domain=(), field_type=(), target=None):
field_type_target=(), implemented=True):
super(FFTOperator, self).__init__(domain=domain, super(FFTOperator, self).__init__(domain=domain,
field_type=field_type, field_type=field_type)
implemented=implemented)
if self.domain == (): if len(self.domain) != 1:
raise TypeError(about._errors.cstring( raise ValueError(about._errors.cstring(
'ERROR: TransformationOperator needs a single space as ' 'ERROR: TransformationOperator accepts only exactly one '
'input domain.' 'space as input domain.'))
))
else:
if len(self.domain) > 1:
raise TypeError(about._errors.cstring(
'ERROR: TransformationOperator accepts only a single '
'space as input domain.'
))
if self.field_type != (): if self.field_type != ():
raise TypeError(about._errors.cstring( raise ValueError(about._errors.cstring(
'ERROR: TransformationOperator field-type has to be an ' 'ERROR: TransformationOperator field-type must be an '
'empty tuple.' 'empty tuple.'
)) ))
# currently not sanitizing the target if target is None:
self._target = self._parse_domain( target = utilities.get_default_codomain(self.domain[0])
utilities.get_default_codomain(self.domain[0])
)
self._field_type_target = self._parse_field_type(field_type_target)
if self.field_type_target != (): self._target = self._parse_domain(
raise TypeError(about._errors.cstring( utilities.get_default_codomain(self.domain[0]))
'ERROR: TransformationOperator target field-type has to be an '
'empty tuple.'
))
self._forward_transformation = TransformationFactory.create( self._forward_transformation = TransformationFactory.create(
self.domain[0], self.target[0] self.domain[0], self.target[0]
...@@ -52,24 +37,37 @@ class FFTOperator(LinearOperator): ...@@ -52,24 +37,37 @@ class FFTOperator(LinearOperator):
self.target[0], self.domain[0] self.target[0], self.domain[0]
) )
def adjoint_times(self, x, spaces=None, types=None): def _times(self, x, spaces, types):
return self.inverse_times(x, spaces, types) spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
def adjoint_inverse_times(self, x, spaces=None, types=None): new_val = self._forward_transformation.transform(x.val, axes=spaces)
return self.times(x, spaces, types)
def inverse_adjoint_times(self, x, spaces=None, types=None): if spaces is None:
return self.times(x, spaces, types) result_domain = self.target
else:
result_domain = list(x.domain)
result_domain[spaces[0]] = self.target[0]
def _times(self, x, spaces, types): result_field = x.copy_empty(domain=result_domain)
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) result_field.set_val(new_val=new_val)
return self._forward_transformation.transform(x.val, axes=spaces) return result_field
def _inverse_times(self, x, spaces, types): def _inverse_times(self, x, spaces, types):
spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(x.domain))
return self._inverse_transformation.transform(x.val, axes=spaces) new_val = self._inverse_transformation.transform(x.val, axes=spaces)
if spaces is None:
result_domain = self.domain
else:
result_domain = list(x.domain)
result_domain[spaces[0]] = self.domain[0]
result_field = x.copy_empty(domain=result_domain)
result_field.set_val(new_val=new_val)
return result_field
# ---Mandatory properties and methods--- # ---Mandatory properties and methods---
...@@ -79,5 +77,12 @@ class FFTOperator(LinearOperator): ...@@ -79,5 +77,12 @@ class FFTOperator(LinearOperator):
@property @property
def field_type_target(self): def field_type_target(self):
return self._field_type_target return self.field_type
@property
def implemented(self):
return True
@property
def unitary(self):
return True
...@@ -12,10 +12,9 @@ import nifty.nifty_utilities as utilities ...@@ -12,10 +12,9 @@ import nifty.nifty_utilities as utilities
class LinearOperator(object): class LinearOperator(object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, domain=(), field_type=(), implemented=False): def __init__(self, domain=(), field_type=()):
self._domain = self._parse_domain(domain) self._domain = self._parse_domain(domain)
self._field_type = self._parse_field_type(field_type) self._field_type = self._parse_field_type(field_type)
self._implemented = bool(implemented)
@property @property
def domain(self): def domain(self):
...@@ -36,8 +35,11 @@ class LinearOperator(object): ...@@ -36,8 +35,11 @@ class LinearOperator(object):
def _parse_domain(self, domain): def _parse_domain(self, domain):
if domain is None: if domain is None:
domain = () domain = ()
elif not isinstance(domain, tuple): elif isinstance(domain, Space):
domain = (domain,) domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain: for d in domain:
if not isinstance(d, Space): if not isinstance(d, Space):
raise TypeError(about._errors.cstring( raise TypeError(about._errors.cstring(
...@@ -48,17 +50,20 @@ class LinearOperator(object): ...@@ -48,17 +50,20 @@ class LinearOperator(object):
def _parse_field_type(self, field_type): def _parse_field_type(self, field_type):
if field_type is None: if field_type is None:
field_type = () field_type = ()
elif not isinstance(field_type, tuple): elif isinstance(field_type, FieldType):
field_type = (field_type,) field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type: for ft in field_type:
if not isinstance(ft, FieldType): if not isinstance(ft, FieldType):
raise TypeError(about._errors.cstring( raise TypeError(about._errors.cstring(
"ERROR: Given object is not a nifty.FieldType.")) "ERROR: Given object is not a nifty.FieldType."))
return field_type return field_type
@property @abc.abstractproperty
def implemented(self): def implemented(self):
return self._implemented raise NotImplementedError
@abc.abstractproperty @abc.abstractproperty
def unitary(self): def unitary(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment