Commit 77c29cf3 authored by Martin Reinecke's avatar Martin Reinecke

move arithmetic functions working on Fields to sugar

parent bb07fd25
Pipeline #29510 passed with stages
in 3 minutes and 59 seconds
......@@ -8,7 +8,7 @@ from .domain_tuple import DomainTuple
from .operators import *
from .field import Field, sqrt, exp, log
from .field import Field
from .probing.utils import probe_with_posterior_samples, probe_diagonal, \
StatCalculator
......
......@@ -19,7 +19,7 @@
from __future__ import division
import numpy as np
from .structured_domain import StructuredDomain
from ..field import Field, exp
from ..field import Field
class LMSpace(StructuredDomain):
......@@ -100,6 +100,8 @@ class LMSpace(StructuredDomain):
# cf. "All-sky convolution for polarimetry experiments"
# by Challinor et al.
# http://arxiv.org/abs/astro-ph/0008228
from ..sugar import exp
res = x+1.
res *= x
res *= -0.5*sigma*sigma
......
......@@ -21,7 +21,7 @@ from builtins import range
from functools import reduce
import numpy as np
from .structured_domain import StructuredDomain
from ..field import Field, exp
from ..field import Field
from .. import dobj
......@@ -144,6 +144,7 @@ class RGSpace(StructuredDomain):
@staticmethod
def _kernel(x, sigma):
from ..sugar import exp
tmp = x*x
tmp *= -2.*np.pi*np.pi*sigma*sigma
exp(tmp, out=tmp)
......
......@@ -733,24 +733,3 @@ for op in ["__add__", "__radd__", "__iadd__",
return NotImplemented
return func2
setattr(Field, op, func(op))
# Arithmetic functions working on Fields
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x, out=None):
fu = getattr(dobj, f)
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
fu(x.val, out=out.val)
return out
else:
return Field(domain=x._domain, val=fu(x.val))
return func2
setattr(_current_module, f, func(f))
......@@ -16,7 +16,8 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..field import Field, exp
from ..field import Field
from ..sugar import exp
from ..minimization.energy import Energy
from ..operators.diagonal_operator import DiagonalOperator
import numpy as np
......
......@@ -16,7 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .. import exp
from ..sugar import exp
from ..minimization.energy import Energy
from ..operators.smoothness_operator import SmoothnessOperator
from ..operators.inversion_enabler import InversionEnabler
......
......@@ -16,8 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from ..field import exp, tanh
from ..sugar import full
from ..sugar import full, exp, tanh
class Linear(object):
......
......@@ -20,7 +20,7 @@ from ..minimization.energy import Energy
from ..operators.diagonal_operator import DiagonalOperator
from ..operators.sandwich_operator import SandwichOperator
from ..operators.inversion_enabler import InversionEnabler
from ..field import log
from ..sugar import log
class PoissonEnergy(Energy):
......
......@@ -6,8 +6,9 @@ __all = ["MultiDomain"]
class frozendict(collections.Mapping):
"""
An immutable wrapper around dictionaries that implements the complete :py:class:`collections.Mapping`
interface. It can be used as a drop-in replacement for dictionaries where immutability is desired.
An immutable wrapper around dictionaries that implements the complete
:py:class:`collections.Mapping` interface. It can be used as a drop-in
replacement for dictionaries where immutability is desired.
"""
dict_cls = dict
......
......@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
import sys
import numpy as np
from .domains.power_space import PowerSpace
from .field import Field
......@@ -30,7 +31,7 @@ from .logger import logger
__all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random',
'full', 'empty', 'from_global_data', 'from_local_data',
'makeDomain']
'makeDomain', 'sqrt', 'exp', 'log', 'tanh', 'conjugate']
def PS_field(pspace, func):
......@@ -199,3 +200,34 @@ def makeDomain(domain):
if isinstance(domain, dict):
return MultiDomain.make(domain)
return DomainTuple.make(domain)
# Arithmetic functions working on Fields
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x, out=None):
if isinstance(x, MultiField):
if out is not None:
if (not isinstance(out, MultiField) or
x._domain != out._domain):
raise ValueError("Bad 'out' argument")
for key, value in x.items():
func2(value, out=out[key])
return out
return MultiField({key: func2(val) for key, val in x.items()})
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
fu = getattr(dobj, f)
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
fu(x.val, out=out.val)
return out
else:
return Field(domain=x._domain, val=fu(x.val))
return func2
setattr(_current_module, f, func(f))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment