Commit 1b43db0e authored by Martin Reinecke's avatar Martin Reinecke
Browse files

misc

parent 67071b5c
Pipeline #72329 passed with stages
in 21 minutes and 44 seconds
......@@ -20,9 +20,10 @@ import numpy as np
from . import utilities
from .domain_tuple import DomainTuple
from .operators.operator import Operator
class Field(object):
class Field(Operator):
"""The discrete representation of a continuous field over multiple spaces.
Stores data arrays and carries all the needed meta-information (i.e. the
......@@ -676,10 +677,8 @@ class Field(object):
def flexible_addsub(self, other, neg):
return self-other if neg else self+other
def clip(self, min=None, max=None):
min = min.val if isinstance(min, Field) else min
max = max.val if isinstance(max, Field) else max
return Field(self._domain, np.clip(self._val, min, max))
def clip(self, a_min=None, a_max=None):
return self.ptw("clip", a_min, a_max)
def _binary_op(self, other, op):
# if other is a field, make sure that the domains match
......@@ -692,13 +691,22 @@ class Field(object):
return Field(self._domain, f(other))
return NotImplemented
def ptw(self, op, with_deriv=False):
def ptw(self, op, *args, **kwargs):
from .pointwise import ptw_dict
if with_deriv:
tmp = ptw_dict[op][1](self._val)
return (Field(self._domain, tmp[0]),
Field(self._domain, tmp[1]))
return Field(self._domain, ptw_dict[op][0](self._val))
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val
for key, val in kwargs.items()}
return Field(self._domain, ptw_dict[op][0](self._val, *argstmp, **kwargstmp))
def ptw_with_deriv(self, op, *args, **kwargs):
from .pointwise import ptw_dict
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val
for key, val in kwargs.items()}
tmp = ptw_dict[op][1](self._val, *argstmp, **kwargstmp)
return (Field(self._domain, tmp[0]), Field(self._domain, tmp[1]))
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
......
......@@ -21,9 +21,10 @@ from .field import Field
from .multi_field import MultiField
from .sugar import makeOp
from . import utilities
from .operators.operator import Operator
class Linearization(object):
class Linearization(Operator):
"""Let `A` be an operator and `x` a field. `Linearization` stores the value
of the operator application (i.e. `A(x)`), the local Jacobian
(i.e. `dA(x)/dx`) and, optionally, the local metric.
......@@ -169,10 +170,9 @@ class Linearization(object):
return self.ptw("reciprocal").__mul__(other)
def __pow__(self, power):
if not np.isscalar(power):
if not (np.isscalar(power) or isinstance(power, (Field, MultiField))):
return NotImplemented
return self.new(self._val**power,
makeOp(self._val**(power-1)).scale(power)(self._jac))
return self.ptw("power", power)
def __mul__(self, other):
from .sugar import makeOp
......@@ -282,22 +282,19 @@ class Linearization(object):
self._val.integrate(spaces),
ContractionOperator(self._jac.target, spaces, 1)(self._jac))
def ptw(self, op):
def ptw(self, op, *args, **kwargs):
from .pointwise import ptw_dict
t1, t2 = self._val.ptw(op, True)
t1, t2 = self._val.ptw_with_deriv(op, *args, **kwargs)
return self.new(t1, makeOp(t2)(self._jac))
def clip(self, min=None, max=None):
tmp = self._val.clip(min, max)
if (min is None) and (max is None):
def clip(self, a_min=None, a_max=None):
if a_min is None and a_max is None:
return self
elif max is None:
tmp2 = makeOp(1. - (tmp == min))
elif min is None:
tmp2 = makeOp(1. - (tmp == max))
else:
tmp2 = makeOp(1. - (tmp == min) - (tmp == max))
return self.new(tmp, tmp2(self._jac))
if not (a_min is None or np.isscalar(a_min) or isinstance(a_min, (Field, MultiField))):
return NotImplemented
if not (a_max is None or np.isscalar(a_max) or isinstance(a_max, (Field, MultiField))):
return NotImplemented
return self.ptw("clip", a_min, a_max)
def add_metric(self, metric):
return self.new(self._val, self._jac, metric)
......
......@@ -21,9 +21,10 @@ from . import utilities
from .field import Field
from .multi_domain import MultiDomain
from .domain_tuple import DomainTuple
from .operators.operator import Operator
class MultiField(object):
class MultiField(Operator):
def __init__(self, domain, val):
"""The discrete representation of a continuous field over a sum space.
......@@ -199,13 +200,8 @@ class MultiField(object):
def conjugate(self):
return self._transform(lambda x: x.conjugate())
def clip(self, min=None, max=None):
ncomp = len(self._val)
lmin = min._val if isinstance(min, MultiField) else (min,)*ncomp
lmax = max._val if isinstance(max, MultiField) else (max,)*ncomp
return MultiField(
self._domain,
tuple(self._val[i].clip(lmin[i], lmax[i]) for i in range(ncomp)))
def clip(self, a_min=None, a_max=None):
return self.ptw("clip", a_min, a_max)
def s_all(self):
for v in self._val:
......@@ -310,12 +306,28 @@ class MultiField(object):
res[key] = -val if neg else val
return MultiField.from_dict(res)
def ptw(self, op, with_deriv=False):
tmp = tuple(val.ptw(op, with_deriv) for val in self.values())
if with_deriv:
return (MultiField(self.domain, tuple(v[0] for v in tmp)),
MultiField(self.domain, tuple(v[1] for v in tmp)))
return MultiField(self.domain, tmp)
def ptw(self, op, *args, **kwargs):
# _check_args(args, kwargs)
tmp = []
for i in range(len(self._val)):
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val[i]
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val[i]
for key, val in kwargs.items()}
tmp.append(self._val[i].ptw(op, *argstmp, **kwargstmp))
return MultiField(self.domain, tuple(tmp))
def ptw_with_deriv(self, op, *args, **kwargs):
# _check_args(args, kwargs)
tmp = []
for i in range(len(self._val)):
argstmp = tuple(arg if arg is None or np.isscalar(arg) else arg._val[i]
for arg in args)
kwargstmp = {key: val if val is None or np.isscalar(val) else val._val[i]
for key, val in kwargs.items()}
tmp.append(self._val[i].ptw_with_deriv(op, *argstmp, **kwargstmp))
return (MultiField(self.domain, tuple(v[0] for v in tmp)),
MultiField(self.domain, tuple(v[1] for v in tmp)))
def _binary_op(self, other, op):
f = getattr(Field, op)
......
......@@ -17,8 +17,6 @@
import numpy as np
from ..field import Field
from ..multi_field import MultiField
from ..utilities import NiftyMeta, indent
......@@ -45,9 +43,65 @@ class Operator(metaclass=NiftyMeta):
-------
target : DomainTuple or MultiDomain
"""
return self._target
@property
def val(self):
"""The numerical value associated with this object
For "pure" operators this is `None`. For Field-like objects this
is a `numpy.ndarray` or a dictionary of `numpy.ndarray`s mathcing the
object's `target`.
Returns
-------
None or numpy.ndarray or dictionary of np.ndarrays : the numerical value
"""
return None
@property
def jac(self):
"""The Jacobian associated with this object
For "pure" operators this is `None`. For Field-like objects this
can be `None` (in which case the object is a constant), or it can be a
`LinearOperator` with `domain` and `target` matching the object's.
Returns
-------
None or LinearOperator : the Jacobian
Notes
-----
if `value` is None, this must be `None` as well!
"""
return None
@property
def want_metric(self):
"""Whether a metric should be computed for the full expression.
This is `False` whenever `jac` is `None`. In other cases it signals
that operators processing this object should compute the metric.
Returns
-------
bool : whether the metric should be computed
"""
return False
@property
def metric(self):
"""The metric associated with the object.
This is `None`, except when all the following conditions hold:
- `want_metric` is `True`
- `target` is the scalar domain
- the operator chain contained an operator which could compute the
metric
Returns
-------
None or LinearOperator : the metric
"""
return None
@staticmethod
def _check_domain_equality(dom_op, dom_field):
if dom_op != dom_field:
......@@ -153,14 +207,22 @@ class Operator(metaclass=NiftyMeta):
return _OpSum(self, -x)
def __pow__(self, power):
if not np.isscalar(power):
from ..field import Field
from ..multi_field import MultiField
if not (np.isscalar(power) or isinstance(power, (Field, MultiField))):
return NotImplemented
return _OpChain.make((_PowerOp(self.target, power), self))
return self.ptw("power", power)
def clip(self, min=None, max=None):
if min is None and max is None:
def clip(self, a_min=None, a_max=None):
from ..field import Field
from ..multi_field import MultiField
if a_min is None and a_max is None:
return self
return _OpChain.make((_Clipper(self.target, min, max), self))
if not (a_min is None or np.isscalar(a_min) or isinstance(a_min, (Field, MultiField))):
return NotImplemented
if not (a_max is None or np.isscalar(a_max) or isinstance(a_max, (Field, MultiField))):
return NotImplemented
return self.ptw("clip", a_min, a_max)
def apply(self, x):
"""Applies the operator to a Field or MultiField.
......@@ -179,6 +241,8 @@ class Operator(metaclass=NiftyMeta):
return self.apply(x.extract(self.domain))
def _check_input(self, x):
from ..field import Field
from ..multi_field import MultiField
from ..linearization import Linearization
from .scaling_operator import ScalingOperator
if not isinstance(x, (Field, MultiField, Linearization)):
......@@ -222,8 +286,8 @@ class Operator(metaclass=NiftyMeta):
def _simplify_for_constant_input_nontrivial(self, c_inp):
return None, self
def ptw(self, op):
return _OpChain.make((_FunctionApplier(self.target, op), self))
def ptw(self, op, *args, **kwargs):
return _OpChain.make((_FunctionApplier(self.target, op, *args, **kwargs), self))
class _ConstCollector(object):
......@@ -287,37 +351,16 @@ class _ConstantOperator(Operator):
class _FunctionApplier(Operator):
def __init__(self, domain, funcname):
def __init__(self, domain, funcname, *args, **kwargs):
from ..sugar import makeDomain
self._domain = self._target = makeDomain(domain)
self._funcname = funcname
self._args = args
self._kwargs = kwargs
def apply(self, x):
self._check_input(x)
return x.ptw(self._funcname)
class _Clipper(Operator):
def __init__(self, domain, min=None, max=None):
from ..sugar import makeDomain
self._domain = self._target = makeDomain(domain)
self._min = min
self._max = max
def apply(self, x):
self._check_input(x)
return x.clip(self._min, self._max)
class _PowerOp(Operator):
def __init__(self, domain, power):
from ..sugar import makeDomain
self._domain = self._target = makeDomain(domain)
self._power = power
def apply(self, x):
self._check_input(x)
return x**self._power
return x.ptw(self._funcname, *self._args, **self._kwargs)
class _CombinedOperator(Operator):
......
......@@ -52,17 +52,33 @@ def _reciprocal_helper(v):
def _abs_helper(v):
if np.iscomplex(v):
if np.issubdtype(v.dtype, np.complexfloating):
raise TypeError("Argument must not be complex")
return (np.abs(v), np.where(v==0, np.nan, np.sign(v)))
def _sign_helper(v):
if np.iscomplex(v):
if np.issubdtype(v.dtype, np.complexfloating):
raise TypeError("Argument must not be complex")
return (np.sign(v), np.where(v==0, np.nan, 0))
def _power_helper(v, expo):
return (np.power(v, expo), expo*np.power(v, expo-1))
def _clip_helper(v, a_min=None, a_max=None):
if np.issubdtype(v.dtype, np.complexfloating):
raise TypeError("Argument must not be complex")
tmp = np.clip(v, a_min, a_max)
tmp2 = np.ones(v.shape)
if a_min is not None:
tmp2 = np.where(tmp==a_min, 0., tmp2)
if a_max is not None:
tmp2 = np.where(tmp==a_max, 0., tmp2)
return (tmp, tmp2)
ptw_dict = {
"sqrt": (np.sqrt, _sqrt_helper),
"sin" : (np.sin, lambda v: (np.sin(v), np.cos(v))),
......@@ -81,5 +97,7 @@ ptw_dict = {
"reciprocal": (lambda v: 1./v, _reciprocal_helper),
"abs": (np.abs, _abs_helper),
"absolute": (np.abs, _abs_helper),
"sign": (np.sign, _sign_helper)
"sign": (np.sign, _sign_helper),
"power": (np.power, _power_helper),
"clip": (np.clip, _clip_helper),
}
......@@ -193,8 +193,8 @@ def test_empty_domain():
def test_trivialities():
s1 = ift.RGSpace((10,))
f1 = ift.Field.full(s1, 27)
assert_equal(f1.clip(min=29).val, 29.)
assert_equal(f1.clip(max=25).val, 25.)
assert_equal(f1.clip(a_min=29).val, 29.)
assert_equal(f1.clip(a_max=25).val, 25.)
assert_equal(f1.val, f1.real.val)
assert_equal(f1.val, (+f1).val)
f1 = ift.Field.full(s1, 27. + 3j)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment