Commit bbdb9944 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

first steps

parent ca2ceecf
Pipeline #28363 passed with stages
in 2 minutes and 37 seconds
...@@ -24,6 +24,9 @@ from .utilities import memo ...@@ -24,6 +24,9 @@ from .utilities import memo
from .logger import logger from .logger import logger
from .multi import *
__all__ = ["__version__", "dobj", "DomainTuple"] + \ __all__ = ["__version__", "dobj", "DomainTuple"] + \
domains.__all__ + operators.__all__ + minimization.__all__ + \ domains.__all__ + operators.__all__ + minimization.__all__ + \
["DomainTuple", "Field", "sqrt", "exp", "log"] ["DomainTuple", "Field", "sqrt", "exp", "log"] + \
multi.__all__
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .multi_linear_operator import MultiLinearOperator
__all__ = ["MultiDomain", "MultiField", "MultiLinearOperator"]
class MultiDomain(dict):
pass
from ..field import Field
import numpy as np
from .multi_domain import MultiDomain
class MultiField(object):
def __init__(self, val):
"""
Parameters
----------
val : dict
"""
self._val = val
def __getitem__(self, key):
return self._val[key]
def keys(self):
return self._val.keys()
def items(self):
return self._val.items()
def values(self):
return self._val.values()
@property
def domain(self):
return MultiDomain({key: val.domain for key, val in self._val.items()})
def _check_domain(self, other):
if other.domain != self.domain:
raise ValueError("domains are incompatible.")
def vdot(self, x):
result = 0.
self._check_domain(x)
for key, sub_field in self.items():
result += sub_field.vdot(x[key])
return result
def lock(self):
for v in self.values():
v.lock()
return self
def copy(self):
return MultiField({key: val.copy() for key, val in self.items()})
@staticmethod
def zeros(domain, dtype=None):
return MultiField({key: Field.zeros(dom, dtype=dtype)
for key, dom in domain.items()})
@staticmethod
def ones(domain, dtype=None):
return MultiField({key: Field.ones(dom, dtype=dtype)
for key, dom in domain.items()})
@staticmethod
def empty(domain, dtype=None):
return MultiField({key: Field.empty(dom, dtype=dtype)
for key, dom in domain.items()})
def norm(self):
""" Computes the L2-norm of the field values.
Returns
-------
norm : float
The L2-norm of the field values.
"""
return np.sqrt(np.abs(self.vdot(x=self)))
def _binary_helper(self, other, op):
if isinstance(other, MultiField):
self._check_domain(other)
result_val = {key: getattr(sub_field,op)(other[key])
for key, sub_field in self.items()}
else:
result_val = {key: getattr(val,op)(other) for key, val in self.items()}
return MultiField(result_val)
def __neg__(self):
return MultiField({key: -val for key, val in self.items()})
def conjugate(self):
return MultiField({key: sub_field.conjugate() for key, sub_field in self.items()})
for op in ["__add__", "__radd__", "__iadd__",
"__sub__", "__rsub__", "__isub__",
"__mul__", "__rmul__", "__imul__",
"__div__", "__rdiv__", "__idiv__",
"__truediv__", "__rtruediv__", "__itruediv__",
"__floordiv__", "__rfloordiv__", "__ifloordiv__",
"__pow__", "__rpow__", "__ipow__",
"__lt__", "__le__", "__gt__", "__ge__", "__eq__", "__ne__"]:
def func(op):
def func2(self, other):
return self._binary_helper(other, op=op)
return func2
setattr(MultiField, op, func(op))
from ..operators.linear_operator import LinearOperator
class MultiLinearOperator(LinearOperator):
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment