From 77c29cf3406bf53564579af26a98bca157472160 Mon Sep 17 00:00:00 2001 From: Martin Reinecke <martin@mpa-garching.mpg.de> Date: Mon, 21 May 2018 11:14:19 +0200 Subject: [PATCH] move arithmetic functions working on Fields to sugar --- nifty4/__init__.py | 2 +- nifty4/domains/lm_space.py | 4 ++- nifty4/domains/rg_space.py | 3 ++- nifty4/field.py | 21 --------------- nifty4/library/noise_energy.py | 3 ++- nifty4/library/nonlinear_power_energy.py | 2 +- nifty4/library/nonlinearities.py | 3 +-- nifty4/library/poisson_energy.py | 2 +- nifty4/multi/multi_domain.py | 5 ++-- nifty4/sugar.py | 34 +++++++++++++++++++++++- 10 files changed, 47 insertions(+), 32 deletions(-) diff --git a/nifty4/__init__.py b/nifty4/__init__.py index d74654389..2f1de367b 100644 --- a/nifty4/__init__.py +++ b/nifty4/__init__.py @@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple from .operators import * -from .field import Field, sqrt, exp, log +from .field import Field from .probing.utils import probe_with_posterior_samples, probe_diagonal, \ StatCalculator diff --git a/nifty4/domains/lm_space.py b/nifty4/domains/lm_space.py index 721171a11..91968a9df 100644 --- a/nifty4/domains/lm_space.py +++ b/nifty4/domains/lm_space.py @@ -19,7 +19,7 @@ from __future__ import division import numpy as np from .structured_domain import StructuredDomain -from ..field import Field, exp +from ..field import Field class LMSpace(StructuredDomain): @@ -100,6 +100,8 @@ class LMSpace(StructuredDomain): # cf. "All-sky convolution for polarimetry experiments" # by Challinor et al. # http://arxiv.org/abs/astro-ph/0008228 + from ..sugar import exp + res = x+1. res *= x res *= -0.5*sigma*sigma diff --git a/nifty4/domains/rg_space.py b/nifty4/domains/rg_space.py index 5f8e0a1fd..118ee47db 100644 --- a/nifty4/domains/rg_space.py +++ b/nifty4/domains/rg_space.py @@ -21,7 +21,7 @@ from builtins import range from functools import reduce import numpy as np from .structured_domain import StructuredDomain -from ..field import Field, exp +from ..field import Field from .. import dobj @@ -144,6 +144,7 @@ class RGSpace(StructuredDomain): @staticmethod def _kernel(x, sigma): + from ..sugar import exp tmp = x*x tmp *= -2.*np.pi*np.pi*sigma*sigma exp(tmp, out=tmp) diff --git a/nifty4/field.py b/nifty4/field.py index c30e01a13..62c74808f 100644 --- a/nifty4/field.py +++ b/nifty4/field.py @@ -733,24 +733,3 @@ for op in ["__add__", "__radd__", "__iadd__", return NotImplemented return func2 setattr(Field, op, func(op)) - - -# Arithmetic functions working on Fields - -_current_module = sys.modules[__name__] - -for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: - def func(f): - def func2(x, out=None): - fu = getattr(dobj, f) - if not isinstance(x, Field): - raise TypeError("This function only accepts Field objects.") - if out is not None: - if not isinstance(out, Field) or x._domain != out._domain: - raise ValueError("Bad 'out' argument") - fu(x.val, out=out.val) - return out - else: - return Field(domain=x._domain, val=fu(x.val)) - return func2 - setattr(_current_module, f, func(f)) diff --git a/nifty4/library/noise_energy.py b/nifty4/library/noise_energy.py index cfad61cfc..dc15146c0 100644 --- a/nifty4/library/noise_energy.py +++ b/nifty4/library/noise_energy.py @@ -16,7 +16,8 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from ..field import Field, exp +from ..field import Field +from ..sugar import exp from ..minimization.energy import Energy from ..operators.diagonal_operator import DiagonalOperator import numpy as np diff --git a/nifty4/library/nonlinear_power_energy.py b/nifty4/library/nonlinear_power_energy.py index e76defdd6..dff636a79 100644 --- a/nifty4/library/nonlinear_power_energy.py +++ b/nifty4/library/nonlinear_power_energy.py @@ -16,7 +16,7 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from .. import exp +from ..sugar import exp from ..minimization.energy import Energy from ..operators.smoothness_operator import SmoothnessOperator from ..operators.inversion_enabler import InversionEnabler diff --git a/nifty4/library/nonlinearities.py b/nifty4/library/nonlinearities.py index ab5a707a7..648290c8b 100644 --- a/nifty4/library/nonlinearities.py +++ b/nifty4/library/nonlinearities.py @@ -16,8 +16,7 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from ..field import exp, tanh -from ..sugar import full +from ..sugar import full, exp, tanh class Linear(object): diff --git a/nifty4/library/poisson_energy.py b/nifty4/library/poisson_energy.py index 637cd28f8..f5cbdf0a1 100644 --- a/nifty4/library/poisson_energy.py +++ b/nifty4/library/poisson_energy.py @@ -20,7 +20,7 @@ from ..minimization.energy import Energy from ..operators.diagonal_operator import DiagonalOperator from ..operators.sandwich_operator import SandwichOperator from ..operators.inversion_enabler import InversionEnabler -from ..field import log +from ..sugar import log class PoissonEnergy(Energy): diff --git a/nifty4/multi/multi_domain.py b/nifty4/multi/multi_domain.py index 619a150da..2f860d89b 100644 --- a/nifty4/multi/multi_domain.py +++ b/nifty4/multi/multi_domain.py @@ -6,8 +6,9 @@ __all = ["MultiDomain"] class frozendict(collections.Mapping): """ - An immutable wrapper around dictionaries that implements the complete :py:class:`collections.Mapping` - interface. It can be used as a drop-in replacement for dictionaries where immutability is desired. + An immutable wrapper around dictionaries that implements the complete + :py:class:`collections.Mapping` interface. It can be used as a drop-in + replacement for dictionaries where immutability is desired. """ dict_cls = dict diff --git a/nifty4/sugar.py b/nifty4/sugar.py index 17cf0fd00..f4bb124cf 100644 --- a/nifty4/sugar.py +++ b/nifty4/sugar.py @@ -16,6 +16,7 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. +import sys import numpy as np from .domains.power_space import PowerSpace from .field import Field @@ -30,7 +31,7 @@ from .logger import logger __all__ = ['PS_field', 'power_analyze', 'create_power_operator', 'create_harmonic_smoothing_operator', 'from_random', 'full', 'empty', 'from_global_data', 'from_local_data', - 'makeDomain'] + 'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate'] def PS_field(pspace, func): @@ -199,3 +200,34 @@ def makeDomain(domain): if isinstance(domain, dict): return MultiDomain.make(domain) return DomainTuple.make(domain) + + +# Arithmetic functions working on Fields + +_current_module = sys.modules[__name__] + +for f in ["sqrt", "exp", "log", "tanh", "conjugate"]: + def func(f): + def func2(x, out=None): + if isinstance(x, MultiField): + if out is not None: + if (not isinstance(out, MultiField) or + x._domain != out._domain): + raise ValueError("Bad 'out' argument") + for key, value in x.items(): + func2(value, out=out[key]) + return out + return MultiField({key: func2(val) for key, val in x.items()}) + + if not isinstance(x, Field): + raise TypeError("This function only accepts Field objects.") + fu = getattr(dobj, f) + if out is not None: + if not isinstance(out, Field) or x._domain != out._domain: + raise ValueError("Bad 'out' argument") + fu(x.val, out=out.val) + return out + else: + return Field(domain=x._domain, val=fu(x.val)) + return func2 + setattr(_current_module, f, func(f)) -- GitLab