From 2f5a515de947dd6edca361921a7a3cda2a97fcfd Mon Sep 17 00:00:00 2001
From: theos <theo.steininger@ultimanet.de>
Date: Fri, 19 Aug 2016 22:22:32 +0200
Subject: [PATCH] First steps towards a new Operators base class.

---
 nifty/__init__.py                             |  3 +-
 nifty/operators/operator/operator.py          | 85 +++++++++++++++++++
 nifty/operators/operator/operator_paradict.py |  7 ++
 nifty/paradict.py                             | 38 +++++++++
 nifty/spaces/power_space/power_indices.py     |  5 --
 nifty/spaces/space/space_paradict.py          | 38 +--------
 .../transformations/transformation_factory.py |  4 +-
 7 files changed, 137 insertions(+), 43 deletions(-)
 create mode 100644 nifty/operators/operator/operator.py
 create mode 100644 nifty/operators/operator/operator_paradict.py
 create mode 100644 nifty/paradict.py

diff --git a/nifty/__init__.py b/nifty/__init__.py
index 7124c420a..bfffe2a63 100644
--- a/nifty/__init__.py
+++ b/nifty/__init__.py
@@ -39,6 +39,7 @@ from d2o import distributed_data_object, d2o_librarian
 
 from nifty_cmaps import ncmap
 from field import Field
+from paradict import Paradict
 
 # this line exists for compatibility reasons
 # TODO: Remove this once the transition to field types is done.
@@ -58,4 +59,4 @@ from spaces import *
 from demos import get_demo_dir
 
 #import pyximport; pyximport.install(pyimport = True)
-from transformations import *
\ No newline at end of file
+from transformations import *
diff --git a/nifty/operators/operator/operator.py b/nifty/operators/operator/operator.py
new file mode 100644
index 000000000..b2c811358
--- /dev/null
+++ b/nifty/operators/operator/operator.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+
+from nifty.config import about
+from operator_paradict import OperatorParadict
+
+
+class LinearOperator(object):
+
+    def __init__(self, domain=None, target=None,
+                 field_type=None, field_type_target=None,
+                 implemented=False, symmetric=False, unitary=False,
+                 **kwargs):
+        self.paradict = OperatorParadict(**kwargs)
+
+        self.implemented = implemented
+        self.symmetric = symmetric
+        self.unitary = unitary
+
+    @property
+    def implemented(self):
+        return self._implemented
+
+    @implemented.setter
+    def implemented(self, b):
+        self._implemented = bool(b)
+
+    @property
+    def symmetric(self):
+        return self._symmetric
+
+    @symmetric.setter
+    def symmetric(self, b):
+        self._symmetric = bool(b)
+
+    @property
+    def unitary(self):
+        return self._unitary
+
+    @unitary.setter
+    def unitary(self, b):
+        self._unitary = bool(b)
+
+    def times(self, x, spaces=None, types=None):
+        raise NotImplementedError
+
+    def adjoint_times(self, x, spaces=None, types=None):
+        raise NotImplementedError
+
+    def inverse_times(self, x, spaces=None, types=None):
+        raise NotImplementedError
+
+    def adjoint_inverse_times(self, x, spaces=None, types=None):
+        raise NotImplementedError
+
+    def inverse_adjoint_times(self, x, spaces=None, types=None):
+        raise NotImplementedError
+
+    def _times(self, x, **kwargs):
+        raise NotImplementedError(about._errors.cstring(
+            "ERROR: no generic instance method 'times'."))
+
+    def _adjoint_times(self, x, **kwargs):
+        raise NotImplementedError(about._errors.cstring(
+            "ERROR: no generic instance method 'adjoint_times'."))
+
+    def _inverse_times(self, x, **kwargs):
+        raise NotImplementedError(about._errors.cstring(
+            "ERROR: no generic instance method 'inverse_times'."))
+
+    def _adjoint_inverse_times(self, x, **kwargs):
+        raise NotImplementedError(about._errors.cstring(
+            "ERROR: no generic instance method 'adjoint_inverse_times'."))
+
+    def _inverse_adjoint_times(self, x, **kwargs):
+        raise NotImplementedError(about._errors.cstring(
+            "ERROR: no generic instance method 'inverse_adjoint_times'."))
+
+    def _check_input_compatibility(self, x, spaces, types):
+        # assert: x is a field
+        # if spaces is None -> assert f.domain == self.domain
+        # -> same for field_type
+        # else: check if self.domain/self.field_type == one entry.
+        #
+
+
diff --git a/nifty/operators/operator/operator_paradict.py b/nifty/operators/operator/operator_paradict.py
new file mode 100644
index 000000000..4e52353bf
--- /dev/null
+++ b/nifty/operators/operator/operator_paradict.py
@@ -0,0 +1,7 @@
+# -*- coding: utf-8 -*-
+
+from nifty.paradict import Paradict
+
+
+class OperatorParadict(Paradict):
+    pass
diff --git a/nifty/paradict.py b/nifty/paradict.py
new file mode 100644
index 000000000..c4c0c0045
--- /dev/null
+++ b/nifty/paradict.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+
+class Paradict(object):
+
+    def __init__(self, **kwargs):
+        if not hasattr(self, 'parameters'):
+            self.parameters = {}
+        for key in kwargs:
+            self[key] = kwargs[key]
+
+    def __iter__(self):
+        return self.parameters.__iter__()
+
+    def __eq__(self, other):
+        return (isinstance(other, self.__class__) and
+                self.__dict__ == other.__dict__)
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __repr__(self):
+        return self.parameters.__repr__()
+
+    def __setitem__(self, key, arg):
+        raise NotImplementedError
+
+    def __getitem__(self, key):
+        return self.parameters.__getitem__(key)
+
+    def __hash__(self):
+        result_hash = 0
+        for (key, item) in self.parameters.items():
+            try:
+                temp_hash = hash(item)
+            except TypeError:
+                temp_hash = hash(tuple(item))
+            result_hash ^= temp_hash ^ int(hash(key)/131)
+        return result_hash
diff --git a/nifty/spaces/power_space/power_indices.py b/nifty/spaces/power_space/power_indices.py
index e246bdb5b..07d9c1140 100644
--- a/nifty/spaces/power_space/power_indices.py
+++ b/nifty/spaces/power_space/power_indices.py
@@ -137,11 +137,6 @@ class PowerIndices(object):
                      "binbounds": temp_binbounds}
         return temp_dict
 
-    def compute_k_array(self):
-        raise NotImplementedError(
-            about._errors.cstring(
-                "ERROR: No generic compute_k_array method implemented."))
-
     def get_index_dict(self, **kwargs):
         """
             Returns a dictionary containing the pindex, kindex, rho and pundex
diff --git a/nifty/spaces/space/space_paradict.py b/nifty/spaces/space/space_paradict.py
index 1c38db112..4cc839c06 100644
--- a/nifty/spaces/space/space_paradict.py
+++ b/nifty/spaces/space/space_paradict.py
@@ -1,39 +1,7 @@
 # -*- coding: utf-8 -*-
 
+from nifty.paradict import Paradict
 
-class SpaceParadict(object):
 
-    def __init__(self, **kwargs):
-        if not hasattr(self, 'parameters'):
-            self.parameters = {}
-        for key in kwargs:
-            self[key] = kwargs[key]
-
-    def __iter__(self):
-        return self.parameters.__iter__()
-
-    def __eq__(self, other):
-        return (isinstance(other, self.__class__) and
-                self.__dict__ == other.__dict__)
-
-    def __ne__(self, other):
-        return not self.__eq__(other)
-
-    def __repr__(self):
-        return self.parameters.__repr__()
-
-    def __setitem__(self, key, arg):
-        raise NotImplementedError
-
-    def __getitem__(self, key):
-        return self.parameters.__getitem__(key)
-
-    def __hash__(self):
-        result_hash = 0
-        for (key, item) in self.parameters.items():
-            try:
-                temp_hash = hash(item)
-            except TypeError:
-                temp_hash = hash(tuple(item))
-            result_hash ^= temp_hash ^ int(hash(key)/131)
-        return result_hash
+class SpaceParadict(Paradict):
+    pass
diff --git a/nifty/transformations/transformation_factory.py b/nifty/transformations/transformation_factory.py
index 8d5275924..fd3501864 100644
--- a/nifty/transformations/transformation_factory.py
+++ b/nifty/transformations/transformation_factory.py
@@ -24,13 +24,13 @@ class _TransformationFactory(object):
                 raise ValueError('ERROR: incompatible codomain')
 
         elif isinstance(domain, GLSpace):
-            if isinstance(codomain, GLSpace):
+            if isinstance(codomain, LMSpace):
                 return GLLMTransformation(domain, codomain, module)
             else:
                 raise ValueError('ERROR: incompatible codomain')
 
         elif isinstance(domain, HPSpace):
-            if isinstance(codomain, GLSpace):
+            if isinstance(codomain, LMSpace):
                 return HPLMTransformation(domain, codomain, module)
             else:
                 raise ValueError('ERROR: incompatible codomain')
-- 
GitLab