diff --git a/nifty4/__init__.py b/nifty4/__init__.py index d74654389083ea70f1f1407a3faa07184bf1b13c..2f1de367bdd911dec158ce18d0998931a6667c7d 100644 --- a/nifty4/__init__.py +++ b/nifty4/__init__.py @@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple from .operators import * -from .field import Field, sqrt, exp, log +from .field import Field from .probing.utils import probe_with_posterior_samples, probe_diagonal, \ StatCalculator diff --git a/nifty4/domains/lm_space.py b/nifty4/domains/lm_space.py index 721171a110a989b9b7f593055c3768617374ce8f..91968a9df0f9b95a9861a6dffc6ca01b033baf53 100644 --- a/nifty4/domains/lm_space.py +++ b/nifty4/domains/lm_space.py @@ -19,7 +19,7 @@ from __future__ import division import numpy as np from .structured_domain import StructuredDomain -from ..field import Field, exp +from ..field import Field class LMSpace(StructuredDomain): @@ -100,6 +100,8 @@ class LMSpace(StructuredDomain): # cf. "All-sky convolution for polarimetry experiments" # by Challinor et al. # http://arxiv.org/abs/astro-ph/0008228 + from ..sugar import exp + res = x+1. res *= x res *= -0.5*sigma*sigma diff --git a/nifty4/domains/rg_space.py b/nifty4/domains/rg_space.py index 5f8e0a1fdfd046a23bf7426bbb748ca295ba1176..118ee47dbeb95ffd2aefa5811ac3e8ea34220496 100644 --- a/nifty4/domains/rg_space.py +++ b/nifty4/domains/rg_space.py @@ -21,7 +21,7 @@ from builtins import range from functools import reduce import numpy as np from .structured_domain import StructuredDomain -from ..field import Field, exp +from ..field import Field from .. import dobj @@ -144,6 +144,7 @@ class RGSpace(StructuredDomain): @staticmethod def _kernel(x, sigma): + from ..sugar import exp tmp = x*x tmp *= -2.*np.pi*np.pi*sigma*sigma exp(tmp, out=tmp) diff --git a/nifty4/field.py b/nifty4/field.py index c30e01a13b1bd1294d1ed62c75bffbcbed8efbfd..62c74808f5ab9f06d2bfa31af0d058deed089d62 100644 --- a/nifty4/field.py +++ b/nifty4/field.py @@ -733,24 +733,3 @@ for op in ["__add__", "__radd__", "__iadd__", return NotImplemented return func2 setattr(Field, op, func(op)) - - -# Arithmetic functions working on Fields - -_current_module = sys.modules[__name__] - -for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: - def func(f): - def func2(x, out=None): - fu = getattr(dobj, f) - if not isinstance(x, Field): - raise TypeError("This function only accepts Field objects.") - if out is not None: - if not isinstance(out, Field) or x._domain != out._domain: - raise ValueError("Bad 'out' argument") - fu(x.val, out=out.val) - return out - else: - return Field(domain=x._domain, val=fu(x.val)) - return func2 - setattr(_current_module, f, func(f)) diff --git a/nifty4/library/noise_energy.py b/nifty4/library/noise_energy.py index cfad61cfcd7a72e2a3330363c82b7c5073d6c9e4..dc15146c08b7226c05a8370f22770709189afa6d 100644 --- a/nifty4/library/noise_energy.py +++ b/nifty4/library/noise_energy.py @@ -16,7 +16,8 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from ..field import Field, exp +from ..field import Field +from ..sugar import exp from ..minimization.energy import Energy from ..operators.diagonal_operator import DiagonalOperator import numpy as np diff --git a/nifty4/library/nonlinear_power_energy.py b/nifty4/library/nonlinear_power_energy.py index e76defdd6121e8f6f2d28efc63ff91155bbf2936..dff636a7934983cb3c494b38bcec5e5351223770 100644 --- a/nifty4/library/nonlinear_power_energy.py +++ b/nifty4/library/nonlinear_power_energy.py @@ -16,7 +16,7 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from .. import exp +from ..sugar import exp from ..minimization.energy import Energy from ..operators.smoothness_operator import SmoothnessOperator from ..operators.inversion_enabler import InversionEnabler diff --git a/nifty4/library/nonlinearities.py b/nifty4/library/nonlinearities.py index ab5a707a70c786418165ed0c51afea8498d0247e..648290c8b2e949928565ae618a77a78c74ea11e2 100644 --- a/nifty4/library/nonlinearities.py +++ b/nifty4/library/nonlinearities.py @@ -16,8 +16,7 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from ..field import exp, tanh -from ..sugar import full +from ..sugar import full, exp, tanh class Linear(object): diff --git a/nifty4/library/poisson_energy.py b/nifty4/library/poisson_energy.py index 637cd28f81348240e68ea6c6cd32a3338ed5b48d..f5cbdf0a1096e380be9d9be64ac0c25699a503db 100644 --- a/nifty4/library/poisson_energy.py +++ b/nifty4/library/poisson_energy.py @@ -20,7 +20,7 @@ from ..minimization.energy import Energy from ..operators.diagonal_operator import DiagonalOperator from ..operators.sandwich_operator import SandwichOperator from ..operators.inversion_enabler import InversionEnabler -from ..field import log +from ..sugar import log class PoissonEnergy(Energy): diff --git a/nifty4/multi/multi_domain.py b/nifty4/multi/multi_domain.py index 619a150daea88b85f25bc30b942ecf3a2ea6f606..2f860d89b2d0f583e82d57cf06c31546956b4d4e 100644 --- a/nifty4/multi/multi_domain.py +++ b/nifty4/multi/multi_domain.py @@ -6,8 +6,9 @@ __all = ["MultiDomain"] class frozendict(collections.Mapping): """ - An immutable wrapper around dictionaries that implements the complete :py:class:`collections.Mapping` - interface. It can be used as a drop-in replacement for dictionaries where immutability is desired. + An immutable wrapper around dictionaries that implements the complete + :py:class:`collections.Mapping` interface. It can be used as a drop-in + replacement for dictionaries where immutability is desired. """ dict_cls = dict diff --git a/nifty4/sugar.py b/nifty4/sugar.py index 17cf0fd009704251dd027ec58e4b1b8db9b0ef4b..f4bb124cfb4fd2a2aed9e40fe62e7699d00b7804 100644 --- a/nifty4/sugar.py +++ b/nifty4/sugar.py @@ -16,6 +16,7 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. +import sys import numpy as np from .domains.power_space import PowerSpace from .field import Field @@ -30,7 +31,7 @@ from .logger import logger __all__ = ['PS_field', 'power_analyze', 'create_power_operator', 'create_harmonic_smoothing_operator', 'from_random', 'full', 'empty', 'from_global_data', 'from_local_data', - 'makeDomain'] + 'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate'] def PS_field(pspace, func): @@ -199,3 +200,34 @@ def makeDomain(domain): if isinstance(domain, dict): return MultiDomain.make(domain) return DomainTuple.make(domain) + + +# Arithmetic functions working on Fields + +_current_module = sys.modules[__name__] + +for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: + def func(f): + def func2(x, out=None): + if isinstance(x, MultiField): + if out is not None: + if (not isinstance(out, MultiField) or + x._domain != out._domain): + raise ValueError("Bad 'out' argument") + for key, value in x.items(): + func2(value, out=out[key]) + return out + return MultiField({key: func2(val) for key, val in x.items()}) + + if not isinstance(x, Field): + raise TypeError("This function only accepts Field objects.") + fu = getattr(dobj, f) + if out is not None: + if not isinstance(out, Field) or x._domain != out._domain: + raise ValueError("Bad 'out' argument") + fu(x.val, out=out.val) + return out + else: + return Field(domain=x._domain, val=fu(x.val)) + return func2 + setattr(_current_module, f, func(f))