From a7bc5e41013a61c872f1ef5728ee46a1e90faa60 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Sat, 19 May 2018 21:13:15 +0200
Subject: [PATCH] step 1

---
 demos/critical_filtering.py               |  4 +-
 demos/krylov_sampling.py                  |  4 +-
 demos/nonlinear_critical_filter.py        |  4 +-
 demos/nonlinear_wiener_filter.py          |  4 +-
 demos/poisson_demo.py                     |  2 +-
 demos/wiener_filter_via_hamiltonian.py    |  2 +-
 nifty4/domain_tuple.py                    |  6 +-
 nifty4/domains/domain.py                  | 12 ++--
 nifty4/extra/operator_tests.py            | 10 +--
 nifty4/field.py                           | 52 ----------------
 nifty4/library/nonlinearities.py          |  7 ++-
 nifty4/multi/multi_domain.py              | 75 ++++++++++++++++++++++-
 nifty4/multi/multi_field.py               | 15 ++---
 nifty4/operators/linear_operator.py       |  4 --
 nifty4/operators/scaling_operator.py      |  2 +-
 nifty4/sugar.py                           | 46 ++++++++++++--
 test/test_energies/test_power.py          |  2 +-
 test/test_field.py                        |  8 +--
 test/test_minimization/test_minimizers.py |  2 +-
 19 files changed, 157 insertions(+), 104 deletions(-)

diff --git a/demos/critical_filtering.py b/demos/critical_filtering.py
index 0011e23f0..d18f94af7 100644
--- a/demos/critical_filtering.py
+++ b/demos/critical_filtering.py
@@ -69,8 +69,8 @@ if __name__ == "__main__":
     # Creating the mock data
     d = noiseless_data + n
 
-    m0 = ift.Field.full(h_space, 1e-7)
-    t0 = ift.Field.full(p_space, -4.)
+    m0 = ift.full(h_space, 1e-7)
+    t0 = ift.full(p_space, -4.)
     power0 = Distributor.times(ift.exp(0.5 * t0))
 
     plotdict = {"colormap": "Planck-like"}
diff --git a/demos/krylov_sampling.py b/demos/krylov_sampling.py
index 7b316e4db..f3275f98f 100644
--- a/demos/krylov_sampling.py
+++ b/demos/krylov_sampling.py
@@ -67,8 +67,8 @@ plt.legend()
 plt.savefig('Krylov_samples_residuals.png')
 plt.close()
 
-D_hat_old = ift.Field.zeros(x_space).to_global_data()
-D_hat_new = ift.Field.zeros(x_space).to_global_data()
+D_hat_old = ift.full(x_space, 0.).to_global_data()
+D_hat_new = ift.full(x_space, 0.).to_global_data()
 for i in range(N_samps):
     D_hat_old += sky(samps_old[i]).to_global_data()**2
     D_hat_new += sky(samps[i]).to_global_data()**2
diff --git a/demos/nonlinear_critical_filter.py b/demos/nonlinear_critical_filter.py
index f1b84ab5c..8477acd9e 100644
--- a/demos/nonlinear_critical_filter.py
+++ b/demos/nonlinear_critical_filter.py
@@ -69,8 +69,8 @@ if __name__ == "__main__":
     # Creating the mock data
     d = noiseless_data + n
 
-    m0 = ift.Field.full(h_space, 1e-7)
-    t0 = ift.Field.full(p_space, -4.)
+    m0 = ift.full(h_space, 1e-7)
+    t0 = ift.full(p_space, -4.)
     power0 = Distributor.times(ift.exp(0.5 * t0))
 
     IC1 = ift.GradientNormController(name="IC1", iteration_limit=100,
diff --git a/demos/nonlinear_wiener_filter.py b/demos/nonlinear_wiener_filter.py
index be5877157..d19e739ac 100644
--- a/demos/nonlinear_wiener_filter.py
+++ b/demos/nonlinear_wiener_filter.py
@@ -36,7 +36,7 @@ if __name__ == "__main__":
     d_space = R.target
 
     p_op = ift.create_power_operator(h_space, p_spec)
-    power = ift.sqrt(p_op(ift.Field.full(h_space, 1.)))
+    power = ift.sqrt(p_op(ift.full(h_space, 1.)))
 
     # Creating the mock data
     true_sky = nonlinearity(HT(power*sh))
@@ -57,7 +57,7 @@ if __name__ == "__main__":
     inverter = ift.ConjugateGradient(controller=ICI)
 
     # initial guess
-    m = ift.Field.full(h_space, 1e-7)
+    m = ift.full(h_space, 1e-7)
     map_energy = ift.library.NonlinearWienerFilterEnergy(
         m, d, R, nonlinearity, HT, power, N, S, inverter=inverter)
 
diff --git a/demos/poisson_demo.py b/demos/poisson_demo.py
index 6c62fa645..a065f589f 100644
--- a/demos/poisson_demo.py
+++ b/demos/poisson_demo.py
@@ -113,7 +113,7 @@ if __name__ == "__main__":
         d_domain, np.random.poisson(lam.local_data).astype(np.float64))
 
     # initial guess
-    psi0 = ift.Field.full(h_domain, 1e-7)
+    psi0 = ift.full(h_domain, 1e-7)
     energy = ift.library.PoissonEnergy(psi0, data, R0, nonlin, HT, Phi_h,
                                        inverter)
     IC1 = ift.GradientNormController(name="IC1", iteration_limit=200,
diff --git a/demos/wiener_filter_via_hamiltonian.py b/demos/wiener_filter_via_hamiltonian.py
index 124a625a2..e83de5c74 100644
--- a/demos/wiener_filter_via_hamiltonian.py
+++ b/demos/wiener_filter_via_hamiltonian.py
@@ -50,7 +50,7 @@ if __name__ == "__main__":
     inverter = ift.ConjugateGradient(controller=ctrl)
     controller = ift.GradientNormController(name="min", tol_abs_gradnorm=0.1)
     minimizer = ift.RelaxedNewton(controller=controller)
-    m0 = ift.Field.zeros(h_space)
+    m0 = ift.full(h_space, 0.)
 
     # Initialize Wiener filter energy
     energy = ift.library.WienerFilterEnergy(position=m0, d=d, R=R, N=N, S=S,
diff --git a/nifty4/domain_tuple.py b/nifty4/domain_tuple.py
index 21779f547..39eafa885 100644
--- a/nifty4/domain_tuple.py
+++ b/nifty4/domain_tuple.py
@@ -34,7 +34,9 @@ class DomainTuple(object):
     """
     _tupleCache = {}
 
-    def __init__(self, domain):
+    def __init__(self, domain, _callingfrommake=False):
+        if not _callingfrommake:
+            raise NotImplementedError
         self._dom = self._parse_domain(domain)
         self._axtuple = self._get_axes_tuple()
         shape_tuple = tuple(sp.shape for sp in self._dom)
@@ -72,7 +74,7 @@ class DomainTuple(object):
         obj = DomainTuple._tupleCache.get(domain)
         if obj is not None:
             return obj
-        obj = DomainTuple(domain)
+        obj = DomainTuple(domain, _callingfrommake=True)
         DomainTuple._tupleCache[domain] = obj
         return obj
 
diff --git a/nifty4/domains/domain.py b/nifty4/domains/domain.py
index 9f22a5baa..38db023bd 100644
--- a/nifty4/domains/domain.py
+++ b/nifty4/domains/domain.py
@@ -23,6 +23,8 @@ from ..utilities import NiftyMetaBase
 class Domain(NiftyMetaBase()):
     """The abstract class repesenting a (structured or unstructured) domain.
     """
+    def __init__(self):
+        self._hash = None
 
     @abc.abstractmethod
     def __repr__(self):
@@ -36,10 +38,12 @@ class Domain(NiftyMetaBase()):
         Only members that are explicitly added to
         :attr:`._needed_for_hash` will be used for hashing.
         """
-        result_hash = 0
-        for key in self._needed_for_hash:
-            result_hash ^= hash(vars(self)[key])
-        return result_hash
+        if self._hash is None:
+            h = 0
+            for key in self._needed_for_hash:
+                h ^= hash(vars(self)[key])
+            self._hash = h
+        return self._hash
 
     def __eq__(self, x):
         """Checks whether two domains are equal.
diff --git a/nifty4/extra/operator_tests.py b/nifty4/extra/operator_tests.py
index 258c2a877..15d2b9109 100644
--- a/nifty4/extra/operator_tests.py
+++ b/nifty4/extra/operator_tests.py
@@ -17,7 +17,7 @@
 # and financially supported by the Studienstiftung des deutschen Volkes.
 
 import numpy as np
-from ..field import Field
+from ..sugar import from_random
 
 __all__ = ["consistency_check"]
 
@@ -26,8 +26,8 @@ def adjoint_implementation(op, domain_dtype, target_dtype, atol, rtol):
     needed_cap = op.TIMES | op.ADJOINT_TIMES
     if (op.capability & needed_cap) != needed_cap:
         return
-    f1 = Field.from_random("normal", op.domain, dtype=domain_dtype).lock()
-    f2 = Field.from_random("normal", op.target, dtype=target_dtype).lock()
+    f1 = from_random("normal", op.domain, dtype=domain_dtype).lock()
+    f2 = from_random("normal", op.target, dtype=target_dtype).lock()
     res1 = f1.vdot(op.adjoint_times(f2).lock())
     res2 = op.times(f1).vdot(f2)
     np.testing.assert_allclose(res1, res2, atol=atol, rtol=rtol)
@@ -37,12 +37,12 @@ def inverse_implementation(op, domain_dtype, target_dtype, atol, rtol):
     needed_cap = op.TIMES | op.INVERSE_TIMES
     if (op.capability & needed_cap) != needed_cap:
         return
-    foo = Field.from_random("normal", op.target, dtype=target_dtype).lock()
+    foo = from_random("normal", op.target, dtype=target_dtype).lock()
     res = op(op.inverse_times(foo).lock())
     np.testing.assert_allclose(res.to_global_data(), res.to_global_data(),
                                atol=atol, rtol=rtol)
 
-    foo = Field.from_random("normal", op.domain, dtype=domain_dtype).lock()
+    foo = from_random("normal", op.domain, dtype=domain_dtype).lock()
     res = op.inverse_times(op(foo).lock())
     np.testing.assert_allclose(res.to_global_data(), foo.to_global_data(),
                                atol=atol, rtol=rtol)
diff --git a/nifty4/field.py b/nifty4/field.py
index 289faf237..c30e01a13 100644
--- a/nifty4/field.py
+++ b/nifty4/field.py
@@ -106,62 +106,10 @@ class Field(object):
             raise TypeError("val must be a scalar")
         return Field(DomainTuple.make(domain), val, dtype)
 
-    @staticmethod
-    def ones(domain, dtype=None):
-        return Field(DomainTuple.make(domain), 1., dtype)
-
-    @staticmethod
-    def zeros(domain, dtype=None):
-        return Field(DomainTuple.make(domain), 0., dtype)
-
     @staticmethod
     def empty(domain, dtype=None):
         return Field(DomainTuple.make(domain), None, dtype)
 
-    @staticmethod
-    def full_like(field, val, dtype=None):
-        """Creates a Field from a template, filled with a constant value.
-
-        Parameters
-        ----------
-        field : Field
-            the template field, from which the domain is inferred
-        val : float/complex/int scalar
-            fill value. Data type of the field is inferred from val.
-
-        Returns
-        -------
-        Field
-            the newly created field
-        """
-        if not isinstance(field, Field):
-            raise TypeError("field must be of Field type")
-        return Field.full(field._domain, val, dtype)
-
-    @staticmethod
-    def zeros_like(field, dtype=None):
-        if not isinstance(field, Field):
-            raise TypeError("field must be of Field type")
-        if dtype is None:
-            dtype = field.dtype
-        return Field.zeros(field._domain, dtype)
-
-    @staticmethod
-    def ones_like(field, dtype=None):
-        if not isinstance(field, Field):
-            raise TypeError("field must be of Field type")
-        if dtype is None:
-            dtype = field.dtype
-        return Field.ones(field._domain, dtype)
-
-    @staticmethod
-    def empty_like(field, dtype=None):
-        if not isinstance(field, Field):
-            raise TypeError("field must be of Field type")
-        if dtype is None:
-            dtype = field.dtype
-        return Field.empty(field._domain, dtype)
-
     @staticmethod
     def from_global_data(domain, arr, sum_up=False):
         """Returns a Field constructed from `domain` and `arr`.
diff --git a/nifty4/library/nonlinearities.py b/nifty4/library/nonlinearities.py
index 810054d0e..ab5a707a7 100644
--- a/nifty4/library/nonlinearities.py
+++ b/nifty4/library/nonlinearities.py
@@ -16,7 +16,8 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
 # and financially supported by the Studienstiftung des deutschen Volkes.
 
-from ..field import Field, exp, tanh
+from ..field import exp, tanh
+from ..sugar import full
 
 
 class Linear(object):
@@ -24,10 +25,10 @@ class Linear(object):
         return x
 
     def derivative(self, x):
-        return Field.ones_like(x)
+        return full(x.domain, 1.)
 
     def hessian(self, x):
-        return Field.zeros_like(x)
+        return full(x.domain, 0.)
 
 
 class Exponential(object):
diff --git a/nifty4/multi/multi_domain.py b/nifty4/multi/multi_domain.py
index 27c5f841d..c77b642af 100644
--- a/nifty4/multi/multi_domain.py
+++ b/nifty4/multi/multi_domain.py
@@ -1,2 +1,73 @@
-class MultiDomain(dict):
-    pass
+import collections
+from ..domain_tuple import DomainTuple
+
+__all = ["MultiDomain"]
+
+
+class frozendict(collections.Mapping):
+    """
+    An immutable wrapper around dictionaries that implements the complete :py:class:`collections.Mapping`
+    interface. It can be used as a drop-in replacement for dictionaries where immutability is desired.
+    """
+
+    dict_cls = dict
+
+    def __init__(self, *args, **kwargs):
+        self._dict = self.dict_cls(*args, **kwargs)
+        self._hash = None
+
+    def __getitem__(self, key):
+        return self._dict[key]
+
+    def __contains__(self, key):
+        return key in self._dict
+
+    def copy(self, **add_or_replace):
+        return self.__class__(self, **add_or_replace)
+
+    def __iter__(self):
+        return iter(self._dict)
+
+    def __len__(self):
+        return len(self._dict)
+
+    def __repr__(self):
+        return '<%s %r>' % (self.__class__.__name__, self._dict)
+
+    def __hash__(self):
+        if self._hash is None:
+            h = 0
+            for key, value in self._dict.items():
+                h ^= hash((key, value))
+            self._hash = h
+        return self._hash
+
+
+class MultiDomain(frozendict):
+    _domainCache = {}
+
+    def __init__(domain, _callingfrommake=False):
+        if not _callingfrommake:
+            raise NotImplementedError
+        super(MultiDomain, self).__init__(domain)
+
+    @staticmethod
+    def make(domain):
+        if isinstance(domain, MultiDomain):
+            return domain
+        print type(domain)
+        if not isinstance(domain, dict):
+            raise TypeError("dict expected")
+        tmp = {}
+        for key, value in domain.items():
+            if not isinstance(key, str):
+                raise TypeError("keys must be strings")
+            tmp[key] = DomainTuple.make(value)
+        domain = frozendict(tmp)
+        print tmp
+        obj = MultiDomain._domainCache.get(domain)
+        if obj is not None:
+            return obj
+        obj = MultiDomain(domain, _callingfrommake=True)
+        MultiDomain._domainCache[domain] = obj
+        return obj
diff --git a/nifty4/multi/multi_field.py b/nifty4/multi/multi_field.py
index 7bcf821bf..454e5c424 100644
--- a/nifty4/multi/multi_field.py
+++ b/nifty4/multi/multi_field.py
@@ -85,21 +85,14 @@ class MultiField(object):
         return {key: dtype for key in domain.keys()}
 
     @staticmethod
-    def zeros(domain, dtype=None):
-        dtype = MultiField.build_dtype(dtype, domain)
-        return MultiField({key: Field.zeros(dom, dtype=dtype[key])
-                           for key, dom in domain.items()})
-
-    @staticmethod
-    def ones(domain, dtype=None):
+    def empty(domain, dtype=None):
         dtype = MultiField.build_dtype(dtype, domain)
-        return MultiField({key: Field.ones(dom, dtype=dtype[key])
+        return MultiField({key: Field.empty(dom, dtype=dtype[key])
                            for key, dom in domain.items()})
 
     @staticmethod
-    def empty(domain, dtype=None):
-        dtype = MultiField.build_dtype(dtype, domain)
-        return MultiField({key: Field.empty(dom, dtype=dtype[key])
+    def full(domain, val):
+        return MultiField({key: Field.full(dom, val)
                            for key, dom in domain.items()})
 
     def norm(self):
diff --git a/nifty4/operators/linear_operator.py b/nifty4/operators/linear_operator.py
index c7d5f2296..185c5049b 100644
--- a/nifty4/operators/linear_operator.py
+++ b/nifty4/operators/linear_operator.py
@@ -271,10 +271,6 @@ class LinearOperator(NiftyMetaBase()):
             raise ValueError("requested operator mode is not supported")
 
     def _check_input(self, x, mode):
-        # MR FIXME: temporary fix for working with MultiFields
-        #if not isinstance(x, Field):
-        #    raise ValueError("supplied object is not a `Field`.")
-
         self._check_mode(mode)
         if x.domain != self._dom(mode):
             raise ValueError("The operator's and field's domains don't match.")
diff --git a/nifty4/operators/scaling_operator.py b/nifty4/operators/scaling_operator.py
index b9a26bfaf..20066cb7d 100644
--- a/nifty4/operators/scaling_operator.py
+++ b/nifty4/operators/scaling_operator.py
@@ -62,7 +62,7 @@ class ScalingOperator(EndomorphicOperator):
         if self._factor == 1.:
             return x.copy()
         if self._factor == 0.:
-            return x.zeros_like(x)
+            return x*0.
 
         if mode == self.TIMES:
             return x*self._factor
diff --git a/nifty4/sugar.py b/nifty4/sugar.py
index 3e829e3ed..3e90bfd2b 100644
--- a/nifty4/sugar.py
+++ b/nifty4/sugar.py
@@ -19,16 +19,18 @@
 import numpy as np
 from .domains.power_space import PowerSpace
 from .field import Field
+from multi.multi_field import MultiField
+from multi.multi_domain import MultiDomain
 from .operators.diagonal_operator import DiagonalOperator
 from .operators.power_distributor import PowerDistributor
 from .domain_tuple import DomainTuple
 from . import dobj, utilities
 from .logger import logger
 
-__all__ = ['PS_field',
-           'power_analyze',
-           'create_power_operator',
-           'create_harmonic_smoothing_operator']
+__all__ = ['PS_field', 'power_analyze', 'create_power_operator',
+           'create_harmonic_smoothing_operator', 'from_random',
+           'full', 'empty', 'from_global_data', 'from_local_data',
+           'makeDomain']
 
 
 def PS_field(pspace, func):
@@ -161,3 +163,39 @@ def create_harmonic_smoothing_operator(domain, space, sigma):
     kfunc = domain[space].get_fft_smoothing_kernel_function(sigma)
     return DiagonalOperator(kfunc(domain[space].get_k_length_array()), domain,
                             space)
+
+
+def full(domain, val):
+    if isinstance(domain, (dict, MultiDomain)):
+        return MultiField.full(domain, val)
+    return Field.full(domain, val)
+
+
+def empty(domain, dtype):
+    if isinstance(domain, (dict, MultiDomain)):
+        return MultiField.empty(domain, dtype)
+    return Field.empty(domain, dtype)
+
+
+def from_random(random_type, domain, dtype=np.float64, **kwargs):
+    if isinstance(domain, (dict, MultiDomain)):
+        return MultiField.from_random(random_type, domain, dtype, **kwargs)
+    return Field.from_random(random_type, domain, dtype, **kwargs)
+
+
+def from_global_data(domain, arr, sum_up=False):
+    if isinstance(domain, (dict, MultiDomain)):
+        return MultiField.from_global_data(domain, arr, sum_up)
+    return Field.from_global_data(domain, arr, sum_up)
+
+
+def from_local_data(domain, arr):
+    if isinstance(domain, (dict, MultiDomain)):
+        return MultiField.from_local_data(domain, arr)
+    return Field.from_local_data(domain, arr)
+
+
+def makeDomain(domain):
+    if isinstance(domain, dict):
+        return MultiDomain.make(domain)
+    return DomainTuple.make(domain)
diff --git a/test/test_energies/test_power.py b/test/test_energies/test_power.py
index 1062f65fc..a8c01f04b 100644
--- a/test/test_energies/test_power.py
+++ b/test/test_energies/test_power.py
@@ -50,7 +50,7 @@ class Energy_Tests(unittest.TestCase):
         n = ift.Field.from_random(domain=space, random_type='normal')
         s = ht(xi * A)
         R = ift.ScalingOperator(10., space)
-        diag = ift.Field.ones(space)
+        diag = ift.full(space, 1.)
         N = ift.DiagonalOperator(diag)
         d = R(f(s)) + n
 
diff --git a/test/test_field.py b/test/test_field.py
index 5515a1d41..1539b317b 100644
--- a/test/test_field.py
+++ b/test/test_field.py
@@ -130,18 +130,18 @@ class Test_Functionality(unittest.TestCase):
         assert_equal(f.local_data, 27)
         assert_equal(f.shape, (200,))
         assert_equal(f.dtype, np.int)
-        fx = ift.Field.empty_like(f)
+        fx = ift.empty(f.domain, f.dtype)
         assert_equal(f.dtype, fx.dtype)
         assert_equal(f.shape, fx.shape)
-        fx = ift.Field.zeros_like(f)
+        fx = ift.full(f.domain, 0)
         assert_equal(f.dtype, fx.dtype)
         assert_equal(f.shape, fx.shape)
         assert_equal(fx.local_data, 0)
-        fx = ift.Field.ones_like(f)
+        fx = ift.full(f.domain, 1)
         assert_equal(f.dtype, fx.dtype)
         assert_equal(f.shape, fx.shape)
         assert_equal(fx.local_data, 1)
-        fx = ift.Field.full_like(f, 67.)
+        fx = ift.full(f.domain, 67.)
         assert_equal(f.shape, fx.shape)
         assert_equal(fx.local_data, 67.)
         f = ift.Field.from_random("normal", s)
diff --git a/test/test_minimization/test_minimizers.py b/test/test_minimization/test_minimizers.py
index 140479ae5..7e8d3ef31 100644
--- a/test/test_minimization/test_minimizers.py
+++ b/test/test_minimization/test_minimizers.py
@@ -53,7 +53,7 @@ class Test_Minimizers(unittest.TestCase):
         covariance_diagonal = ift.Field.from_random(
                                   'uniform', domain=space) + 0.5
         covariance = ift.DiagonalOperator(covariance_diagonal)
-        required_result = ift.Field.ones(space, dtype=np.float64)
+        required_result = ift.full(space, 1.)
 
         try:
             minimizer = eval(minimizer)
-- 
GitLab