Commit 77c29cf3 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

move arithmetic functions working on Fields to sugar

parent bb07fd25
Pipeline #29510 passed with stages
in 3 minutes and 59 seconds
...@@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple ...@@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple
from .operators import * from .operators import *
from .field import Field, sqrt, exp, log from .field import Field
from .probing.utils import probe_with_posterior_samples, probe_diagonal, \ from .probing.utils import probe_with_posterior_samples, probe_diagonal, \
StatCalculator StatCalculator
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
from __future__ import division from __future__ import division
import numpy as np import numpy as np
from .structured_domain import StructuredDomain from .structured_domain import StructuredDomain
from ..field import Field, exp from ..field import Field
class LMSpace(StructuredDomain): class LMSpace(StructuredDomain):
...@@ -100,6 +100,8 @@ class LMSpace(StructuredDomain): ...@@ -100,6 +100,8 @@ class LMSpace(StructuredDomain):
# cf. "All-sky convolution for polarimetry experiments" # cf. "All-sky convolution for polarimetry experiments"
# by Challinor et al. # by Challinor et al.
# http://arxiv.org/abs/astro-ph/0008228 # http://arxiv.org/abs/astro-ph/0008228
from ..sugar import exp
res = x+1. res = x+1.
res *= x res *= x
res *= -0.5*sigma*sigma res *= -0.5*sigma*sigma
......
...@@ -21,7 +21,7 @@ from builtins import range ...@@ -21,7 +21,7 @@ from builtins import range
from functools import reduce from functools import reduce
import numpy as np import numpy as np
from .structured_domain import StructuredDomain from .structured_domain import StructuredDomain
from ..field import Field, exp from ..field import Field
from .. import dobj from .. import dobj
...@@ -144,6 +144,7 @@ class RGSpace(StructuredDomain): ...@@ -144,6 +144,7 @@ class RGSpace(StructuredDomain):
@staticmethod @staticmethod
def _kernel(x, sigma): def _kernel(x, sigma):
from ..sugar import exp
tmp = x*x tmp = x*x
tmp *= -2.*np.pi*np.pi*sigma*sigma tmp *= -2.*np.pi*np.pi*sigma*sigma
exp(tmp, out=tmp) exp(tmp, out=tmp)
......
...@@ -733,24 +733,3 @@ for op in ["__add__", "__radd__", "__iadd__", ...@@ -733,24 +733,3 @@ for op in ["__add__", "__radd__", "__iadd__",
return NotImplemented return NotImplemented
return func2 return func2
setattr(Field, op, func(op)) setattr(Field, op, func(op))
# Arithmetic functions working on Fields
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x, out=None):
fu = getattr(dobj, f)
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
fu(x.val, out=out.val)
return out
else:
return Field(domain=x._domain, val=fu(x.val))
return func2
setattr(_current_module, f, func(f))
...@@ -16,7 +16,8 @@ ...@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from ..field import Field, exp from ..field import Field
from ..sugar import exp
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..operators.diagonal_operator import DiagonalOperator from ..operators.diagonal_operator import DiagonalOperator
import numpy as np import numpy as np
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from .. import exp from ..sugar import exp
from ..minimization.energy import Energy from ..minimization.energy import Energy
from ..operators.smoothness_operator import SmoothnessOperator from ..operators.smoothness_operator import SmoothnessOperator
from ..operators.inversion_enabler import InversionEnabler from ..operators.inversion_enabler import InversionEnabler
......
...@@ -16,8 +16,7 @@ ...@@ -16,8 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
from ..field import exp, tanh from ..sugar import full, exp, tanh
from ..sugar import full
class Linear(object): class Linear(object):
......
...@@ -20,7 +20,7 @@ from ..minimization.energy import Energy ...@@ -20,7 +20,7 @@ from ..minimization.energy import Energy
from ..operators.diagonal_operator import DiagonalOperator from ..operators.diagonal_operator import DiagonalOperator
from ..operators.sandwich_operator import SandwichOperator from ..operators.sandwich_operator import SandwichOperator
from ..operators.inversion_enabler import InversionEnabler from ..operators.inversion_enabler import InversionEnabler
from ..field import log from ..sugar import log
class PoissonEnergy(Energy): class PoissonEnergy(Energy):
......
...@@ -6,8 +6,9 @@ __all = ["MultiDomain"] ...@@ -6,8 +6,9 @@ __all = ["MultiDomain"]
class frozendict(collections.Mapping): class frozendict(collections.Mapping):
""" """
An immutable wrapper around dictionaries that implements the complete :py:class:`collections.Mapping` An immutable wrapper around dictionaries that implements the complete
interface. It can be used as a drop-in replacement for dictionaries where immutability is desired. :py:class:`collections.Mapping` interface. It can be used as a drop-in
replacement for dictionaries where immutability is desired.
""" """
dict_cls = dict dict_cls = dict
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
import sys
import numpy as np import numpy as np
from .domains.power_space import PowerSpace from .domains.power_space import PowerSpace
from .field import Field from .field import Field
...@@ -30,7 +31,7 @@ from .logger import logger ...@@ -30,7 +31,7 @@ from .logger import logger
__all__ = ['PS_field', 'power_analyze', 'create_power_operator', __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random', 'create_harmonic_smoothing_operator', 'from_random',
'full', 'empty', 'from_global_data', 'from_local_data', 'full', 'empty', 'from_global_data', 'from_local_data',
'makeDomain'] 'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate']
def PS_field(pspace, func): def PS_field(pspace, func):
...@@ -199,3 +200,34 @@ def makeDomain(domain): ...@@ -199,3 +200,34 @@ def makeDomain(domain):
if isinstance(domain, dict): if isinstance(domain, dict):
return MultiDomain.make(domain) return MultiDomain.make(domain)
return DomainTuple.make(domain) return DomainTuple.make(domain)
# Arithmetic functions working on Fields
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x, out=None):
if isinstance(x, MultiField):
if out is not None:
if (not isinstance(out, MultiField) or
x._domain != out._domain):
raise ValueError("Bad 'out' argument")
for key, value in x.items():
func2(value, out=out[key])
return out
return MultiField({key: func2(val) for key, val in x.items()})
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
fu = getattr(dobj, f)
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
fu(x.val, out=out.val)
return out
else:
return Field(domain=x._domain, val=fu(x.val))
return func2
setattr(_current_module, f, func(f))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment