Commit 93c14275 authored by Martin Reinecke's avatar Martin Reinecke

begin redesign

parent b8dbbbfa
......@@ -102,5 +102,7 @@ from .multi.block_diagonal_operator import BlockDiagonalOperator
from .energies.kl import SampledKullbachLeiblerDivergence
from .energies.hamiltonian import Hamiltonian
from.operator import Linearization, Operator
# We deliberately don't set __all__ here, because we don't want people to do a
# "from nifty5 import *"; that would swamp the global namespace.
......@@ -32,7 +32,7 @@ class MultiField(object):
Parameters
----------
domain: MultiDomain
val: tuple containing Field or None entries
val: tuple containing Field entries
"""
if not isinstance(domain, MultiDomain):
raise TypeError("domain must be of type MultiDomain")
......@@ -44,8 +44,8 @@ class MultiField(object):
if isinstance(v, Field):
if v._domain is not d:
raise ValueError("domain mismatch")
elif v is not None:
raise TypeError("bad entry in val (must be Field or None)")
else:
raise TypeError("bad entry in val (must be Field)")
self._domain = domain
self._val = val
......@@ -54,8 +54,7 @@ class MultiField(object):
if domain is None:
domain = MultiDomain.make({key: v._domain
for key, v in dict.items()})
return MultiField(domain, tuple(dict[key] if key in dict else None
for key in domain.keys()))
return MultiField(domain, tuple(dict[key] for key in domain.keys()))
def to_dict(self):
return {key: val for key, val in zip(self._domain.keys(), self._val)}
......@@ -81,9 +80,7 @@ class MultiField(object):
# return {key: val.dtype for key, val in self._val.items()}
def _transform(self, op):
return MultiField(
self._domain,
tuple(op(v) if v is not None else None for v in self._val))
return MultiField(self._domain, tuple(op(v) for v in self._val))
@property
def real(self):
......@@ -111,8 +108,7 @@ class MultiField(object):
result = 0.
self._check_domain(x)
for v1, v2 in zip(self._val, x._val):
if v1 is not None and v2 is not None:
result += v1.vdot(v2)
result += v1.vdot(v2)
return result
# @staticmethod
......@@ -190,13 +186,13 @@ class MultiField(object):
def all(self):
for v in self._val:
if v is None or not v.all():
if not v.all():
return False
return True
def any(self):
for v in self._val:
if v is not None and v.any():
if v.any():
return True
return False
......@@ -215,44 +211,9 @@ class MultiField(object):
return True
for op in ["__add__", "__radd__"]:
def func(op):
def func2(self, other):
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = []
for v1, v2 in zip(self._val, other._val):
if v1 is not None:
val.append(v1 if v2 is None else (v1+v2))
else:
val.append(None if v2 is None else v2)
val = tuple(val)
else:
val = tuple(other if v1 is None else (v1+other)
for v1 in self._val)
return MultiField(self._domain, val)
return func2
setattr(MultiField, op, func(op))
for op in ["__mul__", "__rmul__"]:
def func(op):
def func2(self, other):
if isinstance(other, MultiField):
if self._domain is not other._domain:
raise ValueError("domain mismatch")
val = tuple(None if v1 is None or v2 is None else v1*v2
for v1, v2 in zip(self._val, other._val))
else:
val = tuple(None if v1 is None else (v1*other)
for v1 in self._val)
return MultiField(self._domain, val)
return func2
setattr(MultiField, op, func(op))
for op in ["__sub__", "__rsub__",
for op in ["__add__", "__radd__",
"__sub__", "__rsub__",
"__mul__", "__rmul__",
"__div__", "__rdiv__",
"__truediv__", "__rtruediv__",
"__floordiv__", "__rfloordiv__",
......
from __future__ import absolute_import, division, print_function
import abc
import numpy as np
from .compat import *
from .utilities import NiftyMetaBase
#from ..domain_tuple import DomainTuple
#from ..multi.multi_domain import MultiDomain
from .field import Field
from .multi.multi_field import MultiField
from .operators.scaling_operator import ScalingOperator
from .operators.diagonal_operator import DiagonalOperator
class Linearization(object):
def __init__(self, val, jac):
self._val = val
self._jac = jac
@property
def domain(self):
return self._jac.domain
@property
def target(self):
return self._jac.target
@property
def val(self):
return self._val
@property
def jac(self):
return self._jac
def __neg__(self):
return Linearization(-self._val, self._jac*(-1))
def __add__(self, other):
if isinstance(other, Linearization):
return Linearization(self._val+other._val, self._jac+other._jac)
if isinstance(other, (int, float, complex, Field, MultiField)):
return Linearization(self._val+other, self._jac)
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
return self.__add__(-other)
def __rsub__(self, other):
return (-self).__add__(other)
def __mul__(self, other):
if isinstance(other, Linearization):
d1 = DiagonalOperator(self._val)
d2 = DiagonalOperator(other._val)
return Linearization(self._val*other._val,
self._jac*d2 + d1*other._jac)
if isinstance(other, (int, float, complex)):
#if other == 0:
# return ...
return Linearization(self._val*other, self._jac*other)
if isinstance(other, (Field, MultiField)):
d2 = DiagonalOperator(other)
return Linearization(self._val*other, self._jac*d2)
raise TypeError
def __rmul__(self, other):
if isinstance(other, (int, float, complex)):
return Linearization(self._val*other, self._jac*other)
if isinstance(other, (Field, MultiField)):
d1 = DiagonalOperator(other)
return Linearization(self._val*other, d1*self._jac)
@staticmethod
def make_var(field):
return Linearization(field, ScalingOperator(1., field.domain))
@staticmethod
def make_const(field):
return Linearization(field, ScalingOperator(0., {}))
class Operator(NiftyMetaBase()):
"""Transforms values living on one domain into values living on another
domain, and can also provide the Jacobian.
"""
def __call__(self, x):
"""Returns transformed x
Parameters
----------
x : Linearization
input
Returns
-------
Linearization
output
"""
raise NotImplementedError
from __future__ import absolute_import, division, print_function
import numpy as np
import itertools
from ..compat import *
from .. import utilities
from .linear_operator import LinearOperator
from ..domain_tuple import DomainTuple
......
......@@ -62,13 +62,13 @@ class Consistency_Tests(unittest.TestCase):
op = ift.SlopeOperator(dom, tgt, sig)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testSelectionOperator(self, sp1, sp2, dtype):
mdom = ift.MultiDomain.make({'a': sp1, 'b': sp2})
op = ift.SelectionOperator(mdom, 'a')
ift.extra.consistency_check(op, dtype, dtype)
# @expand(product(_h_spaces + _p_spaces + _pow_spaces,
# _h_spaces + _p_spaces + _pow_spaces,
# [np.float64, np.complex128]))
# def testSelectionOperator(self, sp1, sp2, dtype):
# mdom = ift.MultiDomain.make({'a': sp1, 'b': sp2})
# op = ift.SelectionOperator(mdom, 'a')
# ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
......@@ -80,20 +80,20 @@ class Consistency_Tests(unittest.TestCase):
ift.extra.consistency_check(op.inverse.adjoint, dtype, dtype)
ift.extra.consistency_check(op.adjoint.inverse, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testNullOperator(self, sp1, sp2, dtype):
op = ift.NullOperator(sp1, sp2)
ift.extra.consistency_check(op, dtype, dtype)
mdom1 = ift.MultiDomain.make({'a': sp1})
mdom2 = ift.MultiDomain.make({'b': sp2})
op = ift.NullOperator(mdom1, mdom2)
ift.extra.consistency_check(op, dtype, dtype)
op = ift.NullOperator(sp1, mdom2)
ift.extra.consistency_check(op, dtype, dtype)
op = ift.NullOperator(mdom1, sp2)
ift.extra.consistency_check(op, dtype, dtype)
# @expand(product(_h_spaces + _p_spaces + _pow_spaces,
# _h_spaces + _p_spaces + _pow_spaces,
# [np.float64, np.complex128]))
# def testNullOperator(self, sp1, sp2, dtype):
# op = ift.NullOperator(sp1, sp2)
# ift.extra.consistency_check(op, dtype, dtype)
# mdom1 = ift.MultiDomain.make({'a': sp1})
# mdom2 = ift.MultiDomain.make({'b': sp2})
# op = ift.NullOperator(mdom1, mdom2)
# ift.extra.consistency_check(op, dtype, dtype)
# op = ift.NullOperator(sp1, mdom2)
# ift.extra.consistency_check(op, dtype, dtype)
# op = ift.NullOperator(mdom1, sp2)
# ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment