Skip to content
Snippets Groups Projects
Commit 2f5a515d authored by theos's avatar theos
Browse files

First steps towards a new Operators base class.

parent 30623565
No related branches found
No related tags found
1 merge request!29Nif ty3 temp
......@@ -39,6 +39,7 @@ from d2o import distributed_data_object, d2o_librarian
from nifty_cmaps import ncmap
from field import Field
from paradict import Paradict
# this line exists for compatibility reasons
# TODO: Remove this once the transition to field types is done.
......@@ -58,4 +59,4 @@ from spaces import *
from demos import get_demo_dir
#import pyximport; pyximport.install(pyimport = True)
from transformations import *
\ No newline at end of file
from transformations import *
# -*- coding: utf-8 -*-
from nifty.config import about
from operator_paradict import OperatorParadict
class LinearOperator(object):
def __init__(self, domain=None, target=None,
field_type=None, field_type_target=None,
implemented=False, symmetric=False, unitary=False,
**kwargs):
self.paradict = OperatorParadict(**kwargs)
self.implemented = implemented
self.symmetric = symmetric
self.unitary = unitary
@property
def implemented(self):
return self._implemented
@implemented.setter
def implemented(self, b):
self._implemented = bool(b)
@property
def symmetric(self):
return self._symmetric
@symmetric.setter
def symmetric(self, b):
self._symmetric = bool(b)
@property
def unitary(self):
return self._unitary
@unitary.setter
def unitary(self, b):
self._unitary = bool(b)
def times(self, x, spaces=None, types=None):
raise NotImplementedError
def adjoint_times(self, x, spaces=None, types=None):
raise NotImplementedError
def inverse_times(self, x, spaces=None, types=None):
raise NotImplementedError
def adjoint_inverse_times(self, x, spaces=None, types=None):
raise NotImplementedError
def inverse_adjoint_times(self, x, spaces=None, types=None):
raise NotImplementedError
def _times(self, x, **kwargs):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'times'."))
def _adjoint_times(self, x, **kwargs):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'adjoint_times'."))
def _inverse_times(self, x, **kwargs):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'inverse_times'."))
def _adjoint_inverse_times(self, x, **kwargs):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'adjoint_inverse_times'."))
def _inverse_adjoint_times(self, x, **kwargs):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'inverse_adjoint_times'."))
def _check_input_compatibility(self, x, spaces, types):
# assert: x is a field
# if spaces is None -> assert f.domain == self.domain
# -> same for field_type
# else: check if self.domain/self.field_type == one entry.
#
# -*- coding: utf-8 -*-
from nifty.paradict import Paradict
class OperatorParadict(Paradict):
pass
# -*- coding: utf-8 -*-
class Paradict(object):
def __init__(self, **kwargs):
if not hasattr(self, 'parameters'):
self.parameters = {}
for key in kwargs:
self[key] = kwargs[key]
def __iter__(self):
return self.parameters.__iter__()
def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.__dict__ == other.__dict__)
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return self.parameters.__repr__()
def __setitem__(self, key, arg):
raise NotImplementedError
def __getitem__(self, key):
return self.parameters.__getitem__(key)
def __hash__(self):
result_hash = 0
for (key, item) in self.parameters.items():
try:
temp_hash = hash(item)
except TypeError:
temp_hash = hash(tuple(item))
result_hash ^= temp_hash ^ int(hash(key)/131)
return result_hash
......@@ -137,11 +137,6 @@ class PowerIndices(object):
"binbounds": temp_binbounds}
return temp_dict
def compute_k_array(self):
raise NotImplementedError(
about._errors.cstring(
"ERROR: No generic compute_k_array method implemented."))
def get_index_dict(self, **kwargs):
"""
Returns a dictionary containing the pindex, kindex, rho and pundex
......
# -*- coding: utf-8 -*-
from nifty.paradict import Paradict
class SpaceParadict(object):
def __init__(self, **kwargs):
if not hasattr(self, 'parameters'):
self.parameters = {}
for key in kwargs:
self[key] = kwargs[key]
def __iter__(self):
return self.parameters.__iter__()
def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.__dict__ == other.__dict__)
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return self.parameters.__repr__()
def __setitem__(self, key, arg):
raise NotImplementedError
def __getitem__(self, key):
return self.parameters.__getitem__(key)
def __hash__(self):
result_hash = 0
for (key, item) in self.parameters.items():
try:
temp_hash = hash(item)
except TypeError:
temp_hash = hash(tuple(item))
result_hash ^= temp_hash ^ int(hash(key)/131)
return result_hash
class SpaceParadict(Paradict):
pass
......@@ -24,13 +24,13 @@ class _TransformationFactory(object):
raise ValueError('ERROR: incompatible codomain')
elif isinstance(domain, GLSpace):
if isinstance(codomain, GLSpace):
if isinstance(codomain, LMSpace):
return GLLMTransformation(domain, codomain, module)
else:
raise ValueError('ERROR: incompatible codomain')
elif isinstance(domain, HPSpace):
if isinstance(codomain, GLSpace):
if isinstance(codomain, LMSpace):
return HPLMTransformation(domain, codomain, module)
else:
raise ValueError('ERROR: incompatible codomain')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment