From f0f63c40e51db13fd8f76f35932ba8c6c2014c28 Mon Sep 17 00:00:00 2001 From: Philipp Haim Date: Wed, 25 Sep 2019 14:17:54 +0200 Subject: [PATCH] If None is passed as operator, it will be treated as unity --- nifty5/operators/block_diagonal_operator.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/nifty5/operators/block_diagonal_operator.py b/nifty5/operators/block_diagonal_operator.py index fef27414..e8b8ac9f 100644 --- a/nifty5/operators/block_diagonal_operator.py +++ b/nifty5/operators/block_diagonal_operator.py @@ -22,6 +22,7 @@ from ..multi_field import MultiField from .endomorphic_operator import EndomorphicOperator + class BlockDiagonalOperator(EndomorphicOperator): """ Parameters @@ -30,7 +31,7 @@ class BlockDiagonalOperator(EndomorphicOperator): Domain and target of the operator. operators : dict Dictionary with subdomain names as keys and :class:`LinearOperator` s - as items. + as items. Any item None will be treated as unity operator. """ def __init__(self, domain, operators): if not isinstance(domain, MultiDomain): @@ -44,13 +45,15 @@ class BlockDiagonalOperator(EndomorphicOperator): def apply(self, x, mode): self._check_input(x, mode) - val = tuple(op.apply(v, mode=mode) if op is not None else None + val = tuple(op.apply(v, mode=mode) if op is not None else v for op, v in zip(self._ops, x.values())) return MultiField(self._domain, val) def draw_sample(self, from_inverse=False, dtype=np.float64): + from ..sugar import from_random val = tuple(op.draw_sample(from_inverse, dtype) - if op is not None else None for op in self._ops) + if op is not None else from_random('normal', self._domain[key], dtype=dtype) + for op, key in zip(self._ops, self._domain.keys())) return MultiField(self._domain, val) def _combine_chain(self, op): -- GitLab