Skip to content
Snippets Groups Projects
Commit 6fb90ba4 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

add BlockDiagonalOperator

parent 23b0ed81
No related branches found
No related tags found
No related merge requests found
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes. # and financially supported by the Studienstiftung des deutschen Volkes.
def _logger_init(): def _logger_init():
import logging import logging
from . import dobj from . import dobj
......
from .multi_domain import MultiDomain from .multi_domain import MultiDomain
from .multi_field import MultiField from .multi_field import MultiField
from .block_diagonal_operator import BlockDiagonalOperator
__all__ = ["MultiDomain", "MultiField"] __all__ = ["MultiDomain", "MultiField", "BlockDiagonalOperator"]
import numpy as np
from ..operators.endomorphic_operator import EndomorphicOperator
from .multi_domain import MultiDomain
from .multi_field import MultiField
class BlockDiagonalOperator(EndomorphicOperator):
def __init__(self, operators):
"""
Parameters
----------
operators : dict
dictionary with operators domain names as keys and
LinearOperators as items
"""
super(BlockDiagonalOperator, self).__init__()
self._operators = operators
self._domain = MultiDomain(
{key: op.domain for key, op in self._operators.items()})
self._cap = self._all_ops
for op in self._operators.values():
self._cap &= op.capability
@property
def domain(self):
return self._domain
@property
def capability(self):
return self._cap
def apply(self, x, mode):
self._check_input(x, mode)
return MultiField({key: op.apply(x[key], mode=mode)
for key, op in self._operators.items()})
def draw_sample(self, from_inverse=False, dtype=np.float64):
dtype = MultiField.build_dtype(dtype, self._domain)
return MultiField({key: op.draw_sample(from_inverse, dtype[key])
for key, op in self._operators.items()})
def _combine_chain(self, op):
res = {}
for key in self._operators.keys():
res[key] = self._operators[key]*op._operators[key]
return res
def _combine_sum(self, op, selfneg, opneg):
res = {}
for key in self._operators.keys():
res[key] = SumOperator.make([self._operators[key],
op._operators[key]],
[selfneg, opneg])
return res
...@@ -78,6 +78,17 @@ class ChainOperator(LinearOperator): ...@@ -78,6 +78,17 @@ class ChainOperator(LinearOperator):
else: else:
opsnew.append(op) opsnew.append(op)
ops = opsnew ops = opsnew
# Step 5: combine BlockDiagonalOperators where possible
from ..multi.block_diagonal_operator import BlockDiagonalOperator
opsnew = []
for op in ops:
if (len(opsnew) > 0 and
isinstance(opsnew[-1], BlockDiagonalOperator) and
isinstance(op, BlockDiagonalOperator)):
opsnew[-1] = opsnew[-1]._combine_chain(op)
else:
opsnew.append(op)
ops = opsnew
return ops return ops
@staticmethod @staticmethod
......
...@@ -102,6 +102,28 @@ class SumOperator(LinearOperator): ...@@ -102,6 +102,28 @@ class SumOperator(LinearOperator):
negnew.append(neg[i]) negnew.append(neg[i])
ops = opsnew ops = opsnew
neg = negnew neg = negnew
# Step 5: combine BlockDiagonalOperators where possible
from ..multi.block_diagonal_operator import BlockDiagonalOperator
processed = [False] * len(ops)
opsnew = []
negnew = []
for i in range(len(ops)):
if not processed[i]:
if isinstance(ops[i], BlockDiagonalOperator):
op = ops[i]
opneg = neg[i]
for j in range(i+1, len(ops)):
if isinstance(ops[j], BlockDiagonalOperator):
op = op._combine_sum(ops[j], opneg, neg[j])
opneg = False
processed[j] = True
opsnew.append(op)
negnew.append(opneg)
else:
opsnew.append(ops[i])
negnew.append(neg[i])
ops = opsnew
neg = negnew
return ops, neg return ops, neg
@staticmethod @staticmethod
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment