From 2f5a515de947dd6edca361921a7a3cda2a97fcfd Mon Sep 17 00:00:00 2001 From: theos <theo.steininger@ultimanet.de> Date: Fri, 19 Aug 2016 22:22:32 +0200 Subject: [PATCH] First steps towards a new Operators base class. --- nifty/__init__.py | 3 +- nifty/operators/operator/operator.py | 85 +++++++++++++++++++ nifty/operators/operator/operator_paradict.py | 7 ++ nifty/paradict.py | 38 +++++++++ nifty/spaces/power_space/power_indices.py | 5 -- nifty/spaces/space/space_paradict.py | 38 +-------- .../transformations/transformation_factory.py | 4 +- 7 files changed, 137 insertions(+), 43 deletions(-) create mode 100644 nifty/operators/operator/operator.py create mode 100644 nifty/operators/operator/operator_paradict.py create mode 100644 nifty/paradict.py diff --git a/nifty/__init__.py b/nifty/__init__.py index 7124c420a..bfffe2a63 100644 --- a/nifty/__init__.py +++ b/nifty/__init__.py @@ -39,6 +39,7 @@ from d2o import distributed_data_object, d2o_librarian from nifty_cmaps import ncmap from field import Field +from paradict import Paradict # this line exists for compatibility reasons # TODO: Remove this once the transition to field types is done. @@ -58,4 +59,4 @@ from spaces import * from demos import get_demo_dir #import pyximport; pyximport.install(pyimport = True) -from transformations import * \ No newline at end of file +from transformations import * diff --git a/nifty/operators/operator/operator.py b/nifty/operators/operator/operator.py new file mode 100644 index 000000000..b2c811358 --- /dev/null +++ b/nifty/operators/operator/operator.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- + +from nifty.config import about +from operator_paradict import OperatorParadict + + +class LinearOperator(object): + + def __init__(self, domain=None, target=None, + field_type=None, field_type_target=None, + implemented=False, symmetric=False, unitary=False, + **kwargs): + self.paradict = OperatorParadict(**kwargs) + + self.implemented = implemented + self.symmetric = symmetric + self.unitary = unitary + + @property + def implemented(self): + return self._implemented + + @implemented.setter + def implemented(self, b): + self._implemented = bool(b) + + @property + def symmetric(self): + return self._symmetric + + @symmetric.setter + def symmetric(self, b): + self._symmetric = bool(b) + + @property + def unitary(self): + return self._unitary + + @unitary.setter + def unitary(self, b): + self._unitary = bool(b) + + def times(self, x, spaces=None, types=None): + raise NotImplementedError + + def adjoint_times(self, x, spaces=None, types=None): + raise NotImplementedError + + def inverse_times(self, x, spaces=None, types=None): + raise NotImplementedError + + def adjoint_inverse_times(self, x, spaces=None, types=None): + raise NotImplementedError + + def inverse_adjoint_times(self, x, spaces=None, types=None): + raise NotImplementedError + + def _times(self, x, **kwargs): + raise NotImplementedError(about._errors.cstring( + "ERROR: no generic instance method 'times'.")) + + def _adjoint_times(self, x, **kwargs): + raise NotImplementedError(about._errors.cstring( + "ERROR: no generic instance method 'adjoint_times'.")) + + def _inverse_times(self, x, **kwargs): + raise NotImplementedError(about._errors.cstring( + "ERROR: no generic instance method 'inverse_times'.")) + + def _adjoint_inverse_times(self, x, **kwargs): + raise NotImplementedError(about._errors.cstring( + "ERROR: no generic instance method 'adjoint_inverse_times'.")) + + def _inverse_adjoint_times(self, x, **kwargs): + raise NotImplementedError(about._errors.cstring( + "ERROR: no generic instance method 'inverse_adjoint_times'.")) + + def _check_input_compatibility(self, x, spaces, types): + # assert: x is a field + # if spaces is None -> assert f.domain == self.domain + # -> same for field_type + # else: check if self.domain/self.field_type == one entry. + # + + diff --git a/nifty/operators/operator/operator_paradict.py b/nifty/operators/operator/operator_paradict.py new file mode 100644 index 000000000..4e52353bf --- /dev/null +++ b/nifty/operators/operator/operator_paradict.py @@ -0,0 +1,7 @@ +# -*- coding: utf-8 -*- + +from nifty.paradict import Paradict + + +class OperatorParadict(Paradict): + pass diff --git a/nifty/paradict.py b/nifty/paradict.py new file mode 100644 index 000000000..c4c0c0045 --- /dev/null +++ b/nifty/paradict.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +class Paradict(object): + + def __init__(self, **kwargs): + if not hasattr(self, 'parameters'): + self.parameters = {} + for key in kwargs: + self[key] = kwargs[key] + + def __iter__(self): + return self.parameters.__iter__() + + def __eq__(self, other): + return (isinstance(other, self.__class__) and + self.__dict__ == other.__dict__) + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return self.parameters.__repr__() + + def __setitem__(self, key, arg): + raise NotImplementedError + + def __getitem__(self, key): + return self.parameters.__getitem__(key) + + def __hash__(self): + result_hash = 0 + for (key, item) in self.parameters.items(): + try: + temp_hash = hash(item) + except TypeError: + temp_hash = hash(tuple(item)) + result_hash ^= temp_hash ^ int(hash(key)/131) + return result_hash diff --git a/nifty/spaces/power_space/power_indices.py b/nifty/spaces/power_space/power_indices.py index e246bdb5b..07d9c1140 100644 --- a/nifty/spaces/power_space/power_indices.py +++ b/nifty/spaces/power_space/power_indices.py @@ -137,11 +137,6 @@ class PowerIndices(object): "binbounds": temp_binbounds} return temp_dict - def compute_k_array(self): - raise NotImplementedError( - about._errors.cstring( - "ERROR: No generic compute_k_array method implemented.")) - def get_index_dict(self, **kwargs): """ Returns a dictionary containing the pindex, kindex, rho and pundex diff --git a/nifty/spaces/space/space_paradict.py b/nifty/spaces/space/space_paradict.py index 1c38db112..4cc839c06 100644 --- a/nifty/spaces/space/space_paradict.py +++ b/nifty/spaces/space/space_paradict.py @@ -1,39 +1,7 @@ # -*- coding: utf-8 -*- +from nifty.paradict import Paradict -class SpaceParadict(object): - def __init__(self, **kwargs): - if not hasattr(self, 'parameters'): - self.parameters = {} - for key in kwargs: - self[key] = kwargs[key] - - def __iter__(self): - return self.parameters.__iter__() - - def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.__dict__ == other.__dict__) - - def __ne__(self, other): - return not self.__eq__(other) - - def __repr__(self): - return self.parameters.__repr__() - - def __setitem__(self, key, arg): - raise NotImplementedError - - def __getitem__(self, key): - return self.parameters.__getitem__(key) - - def __hash__(self): - result_hash = 0 - for (key, item) in self.parameters.items(): - try: - temp_hash = hash(item) - except TypeError: - temp_hash = hash(tuple(item)) - result_hash ^= temp_hash ^ int(hash(key)/131) - return result_hash +class SpaceParadict(Paradict): + pass diff --git a/nifty/transformations/transformation_factory.py b/nifty/transformations/transformation_factory.py index 8d5275924..fd3501864 100644 --- a/nifty/transformations/transformation_factory.py +++ b/nifty/transformations/transformation_factory.py @@ -24,13 +24,13 @@ class _TransformationFactory(object): raise ValueError('ERROR: incompatible codomain') elif isinstance(domain, GLSpace): - if isinstance(codomain, GLSpace): + if isinstance(codomain, LMSpace): return GLLMTransformation(domain, codomain, module) else: raise ValueError('ERROR: incompatible codomain') elif isinstance(domain, HPSpace): - if isinstance(codomain, GLSpace): + if isinstance(codomain, LMSpace): return HPLMTransformation(domain, codomain, module) else: raise ValueError('ERROR: incompatible codomain') -- GitLab