Commit 393327d5 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'pointwise' into pointwise_ng

parents 01779e03 55ec681a
...@@ -52,7 +52,7 @@ if __name__ == '__main__': ...@@ -52,7 +52,7 @@ if __name__ == '__main__':
A = ift.create_power_operator(harmonic_space, sqrtpspec) A = ift.create_power_operator(harmonic_space, sqrtpspec)
# Set up a sky operator and instrumental response # Set up a sky operator and instrumental response
sky = HT(A).ptw("sigmoid") sky = ift.sigmoid(HT(A))
GR = ift.GeometryRemover(position_space) GR = ift.GeometryRemover(position_space)
R = GR R = GR
......
...@@ -80,7 +80,7 @@ if __name__ == '__main__': ...@@ -80,7 +80,7 @@ if __name__ == '__main__':
A = pd(a) A = pd(a)
# Define sky operator # Define sky operator
sky = HT(ift.makeOp(A)).ptw("exp") sky = ift.exp(HT(ift.makeOp(A)))
M = ift.DiagonalOperator(exposure) M = ift.DiagonalOperator(exposure)
GR = ift.GeometryRemover(position_space) GR = ift.GeometryRemover(position_space)
......
...@@ -85,7 +85,7 @@ if __name__ == '__main__': ...@@ -85,7 +85,7 @@ if __name__ == '__main__':
A = cfmaker.amplitude A = cfmaker.amplitude
# Apply a nonlinearity # Apply a nonlinearity
signal = correlated_field.ptw("sigmoid") signal = ift.sigmoid(correlated_field)
# Build the line-of-sight response and define signal response # Build the line-of-sight response and define signal response
LOS_starts, LOS_ends = random_los(100) if mode == 0 else radial_los(100) LOS_starts, LOS_ends = random_los(100) if mode == 0 else radial_los(100)
...@@ -149,7 +149,7 @@ if __name__ == '__main__': ...@@ -149,7 +149,7 @@ if __name__ == '__main__':
filename_res = filename.format("results") filename_res = filename.format("results")
plot = ift.Plot() plot = ift.Plot()
plot.add(sc.mean, title="Posterior Mean") plot.add(sc.mean, title="Posterior Mean")
plot.add(sc.var.ptw("sqrt"), title="Posterior Standard Deviation") plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation")
powers = [A.force(s + KL.position) for s in KL.samples] powers = [A.force(s + KL.position) for s in KL.samples]
plot.add( plot.add(
......
...@@ -84,7 +84,7 @@ if __name__ == '__main__': ...@@ -84,7 +84,7 @@ if __name__ == '__main__':
DC = SingleDomain(correlated_field.target, position_space) DC = SingleDomain(correlated_field.target, position_space)
## Apply a nonlinearity ## Apply a nonlinearity
signal = DC @ correlated_field.ptw("sigmoid") signal = DC @ ift.sigmoid(correlated_field)
# Build the line-of-sight response and define signal response # Build the line-of-sight response and define signal response
LOS_starts, LOS_ends = random_los(100) if mode == 0 else radial_los(100) LOS_starts, LOS_ends = random_los(100) if mode == 0 else radial_los(100)
...@@ -170,7 +170,7 @@ if __name__ == '__main__': ...@@ -170,7 +170,7 @@ if __name__ == '__main__':
filename_res = filename.format("results") filename_res = filename.format("results")
plot = ift.Plot() plot = ift.Plot()
plot.add(sc.mean, title="Posterior Mean") plot.add(sc.mean, title="Posterior Mean")
plot.add(sc.var.ptw("sqrt"), title="Posterior Standard Deviation") plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation")
powers1 = [A1.force(s + KL.position) for s in KL.samples] powers1 = [A1.force(s + KL.position) for s in KL.samples]
powers2 = [A2.force(s + KL.position) for s in KL.samples] powers2 = [A2.force(s + KL.position) for s in KL.samples]
......
...@@ -686,9 +686,6 @@ class Field(Operator): ...@@ -686,9 +686,6 @@ class Field(Operator):
def flexible_addsub(self, other, neg): def flexible_addsub(self, other, neg):
return self-other if neg else self+other return self-other if neg else self+other
def clip(self, a_min=None, a_max=None):
return self.ptw("clip", a_min, a_max)
def _binary_op(self, other, op): def _binary_op(self, other, op):
# if other is a field, make sure that the domains match # if other is a field, make sure that the domains match
f = getattr(self._val, op) f = getattr(self._val, op)
...@@ -700,20 +697,24 @@ class Field(Operator): ...@@ -700,20 +697,24 @@ class Field(Operator):
return Field(self._domain, f(other)) return Field(self._domain, f(other))
return NotImplemented return NotImplemented
def ptw(self, op, *args, **kwargs): def _prep_args(self, args, kwargs):
from .pointwise import ptw_dict for arg in args + tuple(kwargs.values()):
if not (arg is None or np.isscalar(arg) or arg.jac is None):
raise TypeError("bad argument")
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val
for arg in args) for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val kwargstmp = {key: val if val is None or np.isscalar(val) else val._val
for key, val in kwargs.items()} for key, val in kwargs.items()}
return argstmp, kwargstmp
def ptw(self, op, *args, **kwargs):
from .pointwise import ptw_dict
argstmp, kwargstmp = self._prep_args(args, kwargs)
return Field(self._domain, ptw_dict[op][0](self._val, *argstmp, **kwargstmp)) return Field(self._domain, ptw_dict[op][0](self._val, *argstmp, **kwargstmp))
def ptw_with_deriv(self, op, *args, **kwargs): def ptw_with_deriv(self, op, *args, **kwargs):
from .pointwise import ptw_dict from .pointwise import ptw_dict
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val argstmp, kwargstmp = self._prep_args(args, kwargs)
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val
for key, val in kwargs.items()}
tmp = ptw_dict[op][1](self._val, *argstmp, **kwargstmp) tmp = ptw_dict[op][1](self._val, *argstmp, **kwargstmp)
return (Field(self._domain, tmp[0]), Field(self._domain, tmp[1])) return (Field(self._domain, tmp[0]), Field(self._domain, tmp[1]))
......
...@@ -225,7 +225,7 @@ class _Normalization(Operator): ...@@ -225,7 +225,7 @@ class _Normalization(Operator):
def apply(self, x): def apply(self, x):
self._check_input(x) self._check_input(x)
amp = x.ptw("exp") amp = x.ptw("exp")
spec = amp*amp spec = amp**2
# FIXME This normalizes also the zeromode which is supposed to be left # FIXME This normalizes also the zeromode which is supposed to be left
# untouched by this operator # untouched by this operator
return self._specsum(self._mode_multiplicity(spec))**(-0.5)*amp return self._specsum(self._mode_multiplicity(spec))**(-0.5)*amp
......
...@@ -294,15 +294,6 @@ class Linearization(Operator): ...@@ -294,15 +294,6 @@ class Linearization(Operator):
t1, t2 = self._fld.ptw_with_deriv(op, *args, **kwargs) t1, t2 = self._fld.ptw_with_deriv(op, *args, **kwargs)
return self.new(t1, makeOp(t2)(self._jac)) return self.new(t1, makeOp(t2)(self._jac))
def clip(self, a_min=None, a_max=None):
if a_min is None and a_max is None:
return self
if not (a_min is None or np.isscalar(a_min) or a_min.jac is None):
return NotImplemented
if not (a_max is None or np.isscalar(a_max) or a_max.jac is None):
return NotImplemented
return self.ptw("clip", a_min, a_max)
def add_metric(self, metric): def add_metric(self, metric):
return self.new(self._fld, self._jac, metric) return self.new(self._fld, self._jac, metric)
......
...@@ -314,25 +314,27 @@ class MultiField(Operator): ...@@ -314,25 +314,27 @@ class MultiField(Operator):
res[key] = -val if neg else val res[key] = -val if neg else val
return MultiField.from_dict(res) return MultiField.from_dict(res)
def _prep_args(self, args, kwargs, i):
for arg in args + tuple(kwargs.values()):
if not (arg is None or np.isscalar(arg) or arg.jac is None):
raise TypeError("bad argument")
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val[i]
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val[i]
for key, val in kwargs.items()}
return argstmp, kwargstmp
def ptw(self, op, *args, **kwargs): def ptw(self, op, *args, **kwargs):
# _check_args(args, kwargs)
tmp = [] tmp = []
for i in range(len(self._val)): for i in range(len(self._val)):
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val[i] argstmp, kwargstmp = self._prep_args(args, kwargs, i)
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val[i]
for key, val in kwargs.items()}
tmp.append(self._val[i].ptw(op, *argstmp, **kwargstmp)) tmp.append(self._val[i].ptw(op, *argstmp, **kwargstmp))
return MultiField(self.domain, tuple(tmp)) return MultiField(self.domain, tuple(tmp))
def ptw_with_deriv(self, op, *args, **kwargs): def ptw_with_deriv(self, op, *args, **kwargs):
# _check_args(args, kwargs)
tmp = [] tmp = []
for i in range(len(self._val)): for i in range(len(self._val)):
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val[i] argstmp, kwargstmp = self._prep_args(args, kwargs, i)
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val[i]
for key, val in kwargs.items()}
tmp.append(self._val[i].ptw_with_deriv(op, *argstmp, **kwargstmp)) tmp.append(self._val[i].ptw_with_deriv(op, *argstmp, **kwargstmp))
return (MultiField(self.domain, tuple(v[0] for v in tmp)), return (MultiField(self.domain, tuple(v[0] for v in tmp)),
MultiField(self.domain, tuple(v[1] for v in tmp))) MultiField(self.domain, tuple(v[1] for v in tmp)))
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
import numpy as np import numpy as np
from ..utilities import NiftyMeta, indent from ..utilities import NiftyMeta, indent
from .. import pointwise
class Operator(metaclass=NiftyMeta): class Operator(metaclass=NiftyMeta):
...@@ -221,15 +222,6 @@ class Operator(metaclass=NiftyMeta): ...@@ -221,15 +222,6 @@ class Operator(metaclass=NiftyMeta):
return NotImplemented return NotImplemented
return self.ptw("power", power) return self.ptw("power", power)
def clip(self, a_min=None, a_max=None):
if a_min is None and a_max is None:
return self
if not (a_min is None or np.isscalar(a_min) or a_min.jac is None):
return NotImplemented
if not (a_max is None or np.isscalar(a_max) or a_max.jac is None):
return NotImplemented
return self.ptw("clip", a_min, a_max)
def apply(self, x): def apply(self, x):
"""Applies the operator to a Field or MultiField. """Applies the operator to a Field or MultiField.
...@@ -292,6 +284,14 @@ class Operator(metaclass=NiftyMeta): ...@@ -292,6 +284,14 @@ class Operator(metaclass=NiftyMeta):
return _OpChain.make((_FunctionApplier(self.target, op, *args, **kwargs), self)) return _OpChain.make((_FunctionApplier(self.target, op, *args, **kwargs), self))
for f in pointwise.ptw_dict.keys():
def func(f):
def func2(self, *args, **kwargs):
return self.ptw(f, *args, **kwargs)
return func2
setattr(Operator, f, func(f))
class _ConstCollector(object): class _ConstCollector(object):
def __init__(self): def __init__(self):
self._const = None self._const = None
......
...@@ -67,7 +67,7 @@ def _power_helper(v, expo): ...@@ -67,7 +67,7 @@ def _power_helper(v, expo):
return (np.power(v, expo), expo*np.power(v, expo-1)) return (np.power(v, expo), expo*np.power(v, expo-1))
def _clip_helper(v, a_min=None, a_max=None): def _clip_helper(v, a_min, a_max):
if np.issubdtype(v.dtype, np.complexfloating): if np.issubdtype(v.dtype, np.complexfloating):
raise TypeError("Argument must not be complex") raise TypeError("Argument must not be complex")
tmp = np.clip(v, a_min, a_max) tmp = np.clip(v, a_min, a_max)
......
...@@ -33,13 +33,15 @@ from .operators.distributors import PowerDistributor ...@@ -33,13 +33,15 @@ from .operators.distributors import PowerDistributor
from .operators.operator import Operator from .operators.operator import Operator
from .operators.scaling_operator import ScalingOperator from .operators.scaling_operator import ScalingOperator
from .plot import Plot from .plot import Plot
from . import pointwise
__all__ = ['PS_field', 'power_analyze', 'create_power_operator', __all__ = ['PS_field', 'power_analyze', 'create_power_operator',
'create_harmonic_smoothing_operator', 'from_random', 'create_harmonic_smoothing_operator', 'from_random',
'full', 'makeField', 'full', 'makeField',
'makeDomain', 'get_signal_variance', 'makeOp', 'domain_union', 'makeDomain', 'get_signal_variance', 'makeOp', 'domain_union',
'get_default_codomain', 'single_plot', 'exec_time', 'get_default_codomain', 'single_plot', 'exec_time',
'calculate_position'] 'calculate_position'] + list(pointwise.ptw_dict.keys())
def PS_field(pspace, func): def PS_field(pspace, func):
...@@ -341,7 +343,7 @@ def makeOp(input, dom=None): ...@@ -341,7 +343,7 @@ def makeOp(input, dom=None):
if input is None: if input is None:
return None return None
if np.isscalar(input): if np.isscalar(input):
if not isinstance(dom, (DomaiTuple, MultiDomain)): if not isinstance(dom, (DomainTuple, MultiDomain)):
raise TypeError("need proper `dom` argument") raise TypeError("need proper `dom` argument")
return SalingOperator(dom, input) return SalingOperator(dom, input)
if dom is not None: if dom is not None:
...@@ -373,8 +375,16 @@ def domain_union(domains): ...@@ -373,8 +375,16 @@ def domain_union(domains):
return MultiDomain.union(domains) return MultiDomain.union(domains)
def clip(a, a_min=None, a_max=None): # Pointwise functions
return a.clip(a_min, a_max)
_current_module = sys.modules[__name__]
for f in pointwise.ptw_dict.keys():
def func(f):
def func2(x, *args, **kwargs):
return x.ptw(f, *args, **kwargs)
return func2
setattr(_current_module, f, func(f))
def get_default_codomain(domainoid, space=None): def get_default_codomain(domainoid, space=None):
......
...@@ -193,8 +193,8 @@ def test_empty_domain(): ...@@ -193,8 +193,8 @@ def test_empty_domain():
def test_trivialities(): def test_trivialities():
s1 = ift.RGSpace((10,)) s1 = ift.RGSpace((10,))
f1 = ift.Field.full(s1, 27) f1 = ift.Field.full(s1, 27)
assert_equal(f1.clip(a_min=29).val, 29.) assert_equal(f1.clip(a_min=29, a_max=50).val, 29.)
assert_equal(f1.clip(a_max=25).val, 25.) assert_equal(f1.clip(a_min=0, a_max=25).val, 25.)
assert_equal(f1.val, f1.real.val) assert_equal(f1.val, f1.real.val)
assert_equal(f1.val, (+f1).val) assert_equal(f1.val, (+f1).val)
f1 = ift.Field.full(s1, 27. + 3j) f1 = ift.Field.full(s1, 27. + 3j)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment