Commit 484ddb59 authored by Martin Reinecke's avatar Martin Reinecke

Merge branch 'move_multi' into 'NIFTy_4'

Move "multi"-related functionality from GlobalNewton to NIFTy

See merge request ift/NIFTy!251
parents 62a3585d 0e793273
Pipeline #28500 passed with stages
in 18 minutes and 41 seconds
......@@ -24,6 +24,9 @@ from .utilities import memo
from .logger import logger
from .multi import *
__all__ = ["__version__", "dobj", "DomainTuple"] + \
domains.__all__ + operators.__all__ + minimization.__all__ + \
["DomainTuple", "Field", "sqrt", "exp", "log"]
["DomainTuple", "Field", "sqrt", "exp", "log"] + \
multi.__all__
......@@ -66,6 +66,8 @@ class DomainTuple(object):
"""
if isinstance(domain, DomainTuple):
return domain
if isinstance(domain, dict):
return domain
domain = DomainTuple._parse_domain(domain)
obj = DomainTuple._tupleCache.get(domain)
if obj is not None:
......
......@@ -746,20 +746,6 @@ class Field(object):
raise ValueError("domains are incompatible.")
self.local_data[()] = other.local_data[()]
def _binary_helper(self, other, op):
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
tval = getattr(self.val, op)(other.val)
return self if tval is self.val else Field(self._domain, tval)
if np.isscalar(other) or isinstance(other, dobj.data_object):
tval = getattr(self.val, op)(other)
return self if tval is self.val else Field(self._domain, tval)
return NotImplemented
def __repr__(self):
return "<nifty4.Field>"
......@@ -778,30 +764,38 @@ for op in ["__add__", "__radd__", "__iadd__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
return self._binary_helper(other, op=op)
# if other is a field, make sure that the domains match
if isinstance(other, Field):
if other._domain != self._domain:
raise ValueError("domains are incompatible.")
tval = getattr(self.val, op)(other.val)
return self if tval is self.val else Field(self._domain, tval)
if np.isscalar(other) or isinstance(other, dobj.data_object):
tval = getattr(self.val, op)(other)
return self if tval is self.val else Field(self._domain, tval)
return NotImplemented
return func2
setattr(Field, op, func(op))
# Arithmetic functions working on Fields
def _math_helper(x, function, out):
function = getattr(dobj, function)
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
function(x.val, out=out.val)
return out
else:
return Field(domain=x._domain, val=function(x.val))
_current_module = sys.modules[__name__]
for f in ["sqrt", "exp", "log", "tanh", "conjugate"]:
def func(f):
def func2(x, out=None):
return _math_helper(x, f, out)
fu = getattr(dobj, f)
if not isinstance(x, Field):
raise TypeError("This function only accepts Field objects.")
if out is not None:
if not isinstance(out, Field) or x._domain != out._domain:
raise ValueError("Bad 'out' argument")
fu(x.val, out=out.val)
return out
else:
return Field(domain=x._domain, val=fu(x.val))
return func2
setattr(_current_module, f, func(f))
from .multi_domain import MultiDomain
from .multi_field import MultiField
__all__ = ["MultiDomain", "MultiField"]
class MultiDomain(dict):
pass
from ..field import Field
import numpy as np
from .multi_domain import MultiDomain
class MultiField(object):
def __init__(self, val):
"""
Parameters
----------
val : dict
"""
self._val = val
def __getitem__(self, key):
return self._val[key]
def keys(self):
return self._val.keys()
def items(self):
return self._val.items()
def values(self):
return self._val.values()
@property
def domain(self):
return MultiDomain({key: val.domain for key, val in self._val.items()})
@property
def dtype(self):
return {key: val.dtype for key, val in self._val.items()}
def _check_domain(self, other):
if other.domain != self.domain:
raise ValueError("domains are incompatible.")
def vdot(self, x):
result = 0.
self._check_domain(x)
for key, sub_field in self.items():
result += sub_field.vdot(x[key])
return result
def lock(self):
for v in self.values():
v.lock()
return self
def copy(self):
return MultiField({key: val.copy() for key, val in self.items()})
@staticmethod
def build_dtype(dtype, domain):
if isinstance(dtype, dict):
return dtype
if dtype is None:
dtype = np.float64
return {key: dtype for key in domain.keys()}
@staticmethod
def zeros(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.zeros(dom, dtype=dtype[key])
for key, dom in domain.items()})
@staticmethod
def ones(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.ones(dom, dtype=dtype[key])
for key, dom in domain.items()})
@staticmethod
def empty(domain, dtype=None):
dtype = MultiField.build_dtype(dtype, domain)
return MultiField({key: Field.empty(dom, dtype=dtype[key])
for key, dom in domain.items()})
def norm(self):
""" Computes the L2-norm of the field values.
Returns
-------
norm : float
The L2-norm of the field values.
"""
return np.sqrt(np.abs(self.vdot(x=self)))
def __neg__(self):
return MultiField({key: -val for key, val in self.items()})
def conjugate(self):
return MultiField({key: sub_field.conjugate()
for key, sub_field in self.items()})
for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__",
"__mul__", "__rmul__", "__imul__",
"__div__", "__rdiv__", "__idiv__",
"__truediv__", "__rtruediv__", "__itruediv__",
"__floordiv__", "__rfloordiv__", "__ifloordiv__",
"__pow__", "__rpow__", "__ipow__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
if isinstance(other, MultiField):
self._check_domain(other)
result_val = {key: getattr(sub_field, op)(other[key])
for key, sub_field in self.items()}
else:
result_val = {key: getattr(val, op)(other)
for key, val in self.items()}
return MultiField(result_val)
return func2
setattr(MultiField, op, func(op))
......@@ -18,7 +18,6 @@
from ..minimization.quadratic_energy import QuadraticEnergy
from ..minimization.iteration_controller import IterationController
from ..field import Field
from ..logger import logger
from .endomorphic_operator import EndomorphicOperator
import numpy as np
......@@ -68,7 +67,7 @@ class InversionEnabler(EndomorphicOperator):
if self._op.capability & mode:
return self._op.apply(x, mode)
x0 = Field.zeros(self._tgt(mode), dtype=x.dtype)
x0 = x*0.
invmode = self._modeTable[self.INVERSE_BIT][self._ilog[mode]]
invop = self._op._flip_modes(self._ilog[invmode])
prec = self._approximation
......
......@@ -271,8 +271,9 @@ class LinearOperator(NiftyMetaBase()):
raise ValueError("requested operator mode is not supported")
def _check_input(self, x, mode):
if not isinstance(x, Field):
raise ValueError("supplied object is not a `Field`.")
# MR FIXME: temporary fix for working with MultiFields
#if not isinstance(x, Field):
# raise ValueError("supplied object is not a `Field`.")
self._check_mode(mode)
if x.domain != self._dom(mode):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment