From 77c29cf3406bf53564579af26a98bca157472160 Mon Sep 17 00:00:00 2001
From: Martin Reinecke <martin@mpa-garching.mpg.de>
Date: Mon, 21 May 2018 11:14:19 +0200
Subject: [PATCH] move arithmetic functions working on Fields to sugar

---
 nifty4/__init__.py                       |  2 +-
 nifty4/domains/lm_space.py               |  4 ++-
 nifty4/domains/rg_space.py               |  3 ++-
 nifty4/field.py                          | 21 ---------------
 nifty4/library/noise_energy.py           |  3 ++-
 nifty4/library/nonlinear_power_energy.py |  2 +-
 nifty4/library/nonlinearities.py         |  3 +--
 nifty4/library/poisson_energy.py         |  2 +-
 nifty4/multi/multi_domain.py             |  5 ++--
 nifty4/sugar.py                          | 34 +++++++++++++++++++++++-
 10 files changed, 47 insertions(+), 32 deletions(-)

diff --git a/nifty4/__init__.py b/nifty4/__init__.py
index d74654389..2f1de367b 100644
--- a/nifty4/__init__.py
+++ b/nifty4/__init__.py
@@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple
 
 from .operators import *
 
-from .field import Field, sqrt, exp, log
+from .field import Field
 
 from .probing.utils import probe_with_posterior_samples, probe_diagonal, \
     StatCalculator
diff --git a/nifty4/domains/lm_space.py b/nifty4/domains/lm_space.py
index 721171a11..91968a9df 100644
--- a/nifty4/domains/lm_space.py
+++ b/nifty4/domains/lm_space.py
@@ -19,7 +19,7 @@
 from __future__ import division
 import numpy as np
 from .structured_domain import StructuredDomain
-from ..field import Field, exp
+from ..field import Field
 
 
 class LMSpace(StructuredDomain):
@@ -100,6 +100,8 @@ class LMSpace(StructuredDomain):
         # cf. "All-sky convolution for polarimetry experiments"
         # by Challinor et al.
         # http://arxiv.org/abs/astro-ph/0008228
+        from ..sugar import exp
+
         res = x+1.
         res *= x
         res *= -0.5*sigma*sigma
diff --git a/nifty4/domains/rg_space.py b/nifty4/domains/rg_space.py
index 5f8e0a1fd..118ee47db 100644
--- a/nifty4/domains/rg_space.py
+++ b/nifty4/domains/rg_space.py
@@ -21,7 +21,7 @@ from builtins import range
 from functools import reduce
 import numpy as np
 from .structured_domain import StructuredDomain
-from ..field import Field, exp
+from ..field import Field
 from .. import dobj
 
 
@@ -144,6 +144,7 @@ class RGSpace(StructuredDomain):
 
     @staticmethod
     def _kernel(x, sigma):
+        from ..sugar import exp
         tmp = x*x
         tmp *= -2.*np.pi*np.pi*sigma*sigma
         exp(tmp, out=tmp)
diff --git a/nifty4/field.py b/nifty4/field.py
index c30e01a13..62c74808f 100644
--- a/nifty4/field.py
+++ b/nifty4/field.py
@@ -733,24 +733,3 @@ for op in ["__add__", "__radd__", "__iadd__",
             return NotImplemented
         return func2
     setattr(Field, op, func(op))
-
-
-# Arithmetic functions working on Fields
-
-_current_module = sys.modules[__name__]
-
-for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
-    def func(f):
-        def func2(x, out=None):
-            fu = getattr(dobj, f)
-            if not isinstance(x, Field):
-                raise TypeError("This function only accepts Field objects.")
-            if out is not None:
-                if not isinstance(out, Field) or x._domain != out._domain:
-                    raise ValueError("Bad 'out' argument")
-                fu(x.val, out=out.val)
-                return out
-            else:
-                return Field(domain=x._domain, val=fu(x.val))
-        return func2
-    setattr(_current_module, f, func(f))
diff --git a/nifty4/library/noise_energy.py b/nifty4/library/noise_energy.py
index cfad61cfc..dc15146c0 100644
--- a/nifty4/library/noise_energy.py
+++ b/nifty4/library/noise_energy.py
@@ -16,7 +16,8 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
 # and financially supported by the Studienstiftung des deutschen Volkes.
 
-from ..field import Field, exp
+from ..field import Field
+from ..sugar import exp
 from ..minimization.energy import Energy
 from ..operators.diagonal_operator import DiagonalOperator
 import numpy as np
diff --git a/nifty4/library/nonlinear_power_energy.py b/nifty4/library/nonlinear_power_energy.py
index e76defdd6..dff636a79 100644
--- a/nifty4/library/nonlinear_power_energy.py
+++ b/nifty4/library/nonlinear_power_energy.py
@@ -16,7 +16,7 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
 # and financially supported by the Studienstiftung des deutschen Volkes.
 
-from .. import exp
+from ..sugar import exp
 from ..minimization.energy import Energy
 from ..operators.smoothness_operator import SmoothnessOperator
 from ..operators.inversion_enabler import InversionEnabler
diff --git a/nifty4/library/nonlinearities.py b/nifty4/library/nonlinearities.py
index ab5a707a7..648290c8b 100644
--- a/nifty4/library/nonlinearities.py
+++ b/nifty4/library/nonlinearities.py
@@ -16,8 +16,7 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
 # and financially supported by the Studienstiftung des deutschen Volkes.
 
-from ..field import exp, tanh
-from ..sugar import full
+from ..sugar import full, exp, tanh
 
 
 class Linear(object):
diff --git a/nifty4/library/poisson_energy.py b/nifty4/library/poisson_energy.py
index 637cd28f8..f5cbdf0a1 100644
--- a/nifty4/library/poisson_energy.py
+++ b/nifty4/library/poisson_energy.py
@@ -20,7 +20,7 @@ from ..minimization.energy import Energy
 from ..operators.diagonal_operator import DiagonalOperator
 from ..operators.sandwich_operator import SandwichOperator
 from ..operators.inversion_enabler import InversionEnabler
-from ..field import log
+from ..sugar import log
 
 
 class PoissonEnergy(Energy):
diff --git a/nifty4/multi/multi_domain.py b/nifty4/multi/multi_domain.py
index 619a150da..2f860d89b 100644
--- a/nifty4/multi/multi_domain.py
+++ b/nifty4/multi/multi_domain.py
@@ -6,8 +6,9 @@ __all = ["MultiDomain"]
 
 class frozendict(collections.Mapping):
     """
-    An immutable wrapper around dictionaries that implements the complete :py:class:`collections.Mapping`
-    interface. It can be used as a drop-in replacement for dictionaries where immutability is desired.
+    An immutable wrapper around dictionaries that implements the complete
+    :py:class:`collections.Mapping` interface. It can be used as a drop-in
+    replacement for dictionaries where immutability is desired.
     """
 
     dict_cls = dict
diff --git a/nifty4/sugar.py b/nifty4/sugar.py
index 17cf0fd00..f4bb124cf 100644
--- a/nifty4/sugar.py
+++ b/nifty4/sugar.py
@@ -16,6 +16,7 @@
 # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
 # and financially supported by the Studienstiftung des deutschen Volkes.
 
+import sys
 import numpy as np
 from .domains.power_space import PowerSpace
 from .field import Field
@@ -30,7 +31,7 @@ from .logger import logger
 __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
            'create_harmonic_smoothing_operator', 'from_random',
            'full', 'empty', 'from_global_data', 'from_local_data',
-           'makeDomain']
+           'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate']
 
 
 def PS_field(pspace, func):
@@ -199,3 +200,34 @@ def makeDomain(domain):
     if isinstance(domain, dict):
         return MultiDomain.make(domain)
     return DomainTuple.make(domain)
+
+
+# Arithmetic functions working on Fields
+
+_current_module = sys.modules[__name__]
+
+for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
+    def func(f):
+        def func2(x, out=None):
+            if isinstance(x, MultiField):
+                if out is not None:
+                    if (not isinstance(out, MultiField) or
+                            x._domain != out._domain):
+                        raise ValueError("Bad 'out' argument")
+                    for key, value in x.items():
+                        func2(value, out=out[key])
+                    return out
+                return MultiField({key: func2(val) for key, val in x.items()})
+
+            if not isinstance(x, Field):
+                raise TypeError("This function only accepts Field objects.")
+            fu = getattr(dobj, f)
+            if out is not None:
+                if not isinstance(out, Field) or x._domain != out._domain:
+                    raise ValueError("Bad 'out' argument")
+                fu(x.val, out=out.val)
+                return out
+            else:
+                return Field(domain=x._domain, val=fu(x.val))
+        return func2
+    setattr(_current_module, f, func(f))
-- 
GitLab