Commit 6fb90ba4 authored by Martin Reinecke's avatar Martin Reinecke

add BlockDiagonalOperator

parent 23b0ed81
Pipeline #29462 passed with stages
in 4 minutes and 14 seconds
......@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
def _logger_init():
import logging
from . import dobj
......
from .multi_domain import MultiDomain
from .multi_field import MultiField
from .block_diagonal_operator import BlockDiagonalOperator
__all__ = ["MultiDomain", "MultiField"]
__all__ = ["MultiDomain", "MultiField", "BlockDiagonalOperator"]
import numpy as np
from ..operators.endomorphic_operator import EndomorphicOperator
from .multi_domain import MultiDomain
from .multi_field import MultiField
class BlockDiagonalOperator(EndomorphicOperator):
def __init__(self, operators):
"""
Parameters
----------
operators : dict
dictionary with operators domain names as keys and
LinearOperators as items
"""
super(BlockDiagonalOperator, self).__init__()
self._operators = operators
self._domain = MultiDomain(
{key: op.domain for key, op in self._operators.items()})
self._cap = self._all_ops
for op in self._operators.values():
self._cap &= op.capability
@property
def domain(self):
return self._domain
@property
def capability(self):
return self._cap
def apply(self, x, mode):
self._check_input(x, mode)
return MultiField({key: op.apply(x[key], mode=mode)
for key, op in self._operators.items()})
def draw_sample(self, from_inverse=False, dtype=np.float64):
dtype = MultiField.build_dtype(dtype, self._domain)
return MultiField({key: op.draw_sample(from_inverse, dtype[key])
for key, op in self._operators.items()})
def _combine_chain(self, op):
res = {}
for key in self._operators.keys():
res[key] = self._operators[key]*op._operators[key]
return res
def _combine_sum(self, op, selfneg, opneg):
res = {}
for key in self._operators.keys():
res[key] = SumOperator.make([self._operators[key],
op._operators[key]],
[selfneg, opneg])
return res
......@@ -78,6 +78,17 @@ class ChainOperator(LinearOperator):
else:
opsnew.append(op)
ops = opsnew
# Step 5: combine BlockDiagonalOperators where possible
from ..multi.block_diagonal_operator import BlockDiagonalOperator
opsnew = []
for op in ops:
if (len(opsnew) > 0 and
isinstance(opsnew[-1], BlockDiagonalOperator) and
isinstance(op, BlockDiagonalOperator)):
opsnew[-1] = opsnew[-1]._combine_chain(op)
else:
opsnew.append(op)
ops = opsnew
return ops
@staticmethod
......
......@@ -102,6 +102,28 @@ class SumOperator(LinearOperator):
negnew.append(neg[i])
ops = opsnew
neg = negnew
# Step 5: combine BlockDiagonalOperators where possible
from ..multi.block_diagonal_operator import BlockDiagonalOperator
processed = [False] * len(ops)
opsnew = []
negnew = []
for i in range(len(ops)):
if not processed[i]:
if isinstance(ops[i], BlockDiagonalOperator):
op = ops[i]
opneg = neg[i]
for j in range(i+1, len(ops)):
if isinstance(ops[j], BlockDiagonalOperator):
op = op._combine_sum(ops[j], opneg, neg[j])
opneg = False
processed[j] = True
opsnew.append(op)
negnew.append(opneg)
else:
opsnew.append(ops[i])
negnew.append(neg[i])
ops = opsnew
neg = negnew
return ops, neg
@staticmethod
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment