Commit f0f63c40 authored by Philipp Haim's avatar Philipp Haim

If None is passed as operator, it will be treated as unity

parent 6e515642
...@@ -22,6 +22,7 @@ from ..multi_field import MultiField ...@@ -22,6 +22,7 @@ from ..multi_field import MultiField
from .endomorphic_operator import EndomorphicOperator from .endomorphic_operator import EndomorphicOperator
class BlockDiagonalOperator(EndomorphicOperator): class BlockDiagonalOperator(EndomorphicOperator):
""" """
Parameters Parameters
...@@ -30,7 +31,7 @@ class BlockDiagonalOperator(EndomorphicOperator): ...@@ -30,7 +31,7 @@ class BlockDiagonalOperator(EndomorphicOperator):
Domain and target of the operator. Domain and target of the operator.
operators : dict operators : dict
Dictionary with subdomain names as keys and :class:`LinearOperator` s Dictionary with subdomain names as keys and :class:`LinearOperator` s
as items. as items. Any item None will be treated as unity operator.
""" """
def __init__(self, domain, operators): def __init__(self, domain, operators):
if not isinstance(domain, MultiDomain): if not isinstance(domain, MultiDomain):
...@@ -44,13 +45,15 @@ class BlockDiagonalOperator(EndomorphicOperator): ...@@ -44,13 +45,15 @@ class BlockDiagonalOperator(EndomorphicOperator):
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
val = tuple(op.apply(v, mode=mode) if op is not None else None val = tuple(op.apply(v, mode=mode) if op is not None else v
for op, v in zip(self._ops, x.values())) for op, v in zip(self._ops, x.values()))
return MultiField(self._domain, val) return MultiField(self._domain, val)
def draw_sample(self, from_inverse=False, dtype=np.float64): def draw_sample(self, from_inverse=False, dtype=np.float64):
from ..sugar import from_random
val = tuple(op.draw_sample(from_inverse, dtype) val = tuple(op.draw_sample(from_inverse, dtype)
if op is not None else None for op in self._ops) if op is not None else from_random('normal', self._domain[key], dtype=dtype)
for op, key in zip(self._ops, self._domain.keys()))
return MultiField(self._domain, val) return MultiField(self._domain, val)
def _combine_chain(self, op): def _combine_chain(self, op):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment