block_diagonal_operator.py 1.89 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import numpy as np
from ..operators.endomorphic_operator import EndomorphicOperator
from .multi_domain import MultiDomain
from .multi_field import MultiField


class BlockDiagonalOperator(EndomorphicOperator):
    """Operator acting blockwise on the entries of a MultiField.

    The entry stored under ``key`` is processed only by the operator stored
    under the same ``key``; different keys are never mixed.

    Parameters
    ----------
    operators : dict
        dictionary with operators domain names as keys and
        LinearOperators as items
    """

    def __init__(self, operators):
        super(BlockDiagonalOperator, self).__init__()
        # Keep a private copy: the domain and capability below are computed
        # once, so later changes to the caller's dict must not leak in.
        self._operators = dict(operators)
        self._domain = MultiDomain.make(
            {key: op.domain for key, op in self._operators.items()})
        # Only modes supported by every block are supported by the whole.
        self._cap = self._all_ops
        for op in self._operators.values():
            self._cap &= op.capability

    @property
    def domain(self):
        """MultiDomain built from the domains of the block operators."""
        return self._domain

    @property
    def capability(self):
        """Intersection (bitwise AND) of the blocks' capabilities."""
        return self._cap

    def apply(self, x, mode):
        """Apply each block operator to the entry of ``x`` with its key.

        ``mode`` is forwarded unchanged to every block operator; the results
        are collected into a new MultiField.
        """
        self._check_input(x, mode)
        return MultiField({key: op.apply(x[key], mode=mode)
                           for key, op in self._operators.items()})

    def draw_sample(self, from_inverse=False, dtype=np.float64):
        """Draw one sample per block and combine them into a MultiField.

        ``dtype`` is expanded into a per-key mapping with
        ``MultiField.build_dtype`` before being passed to each block.
        """
        dtype = MultiField.build_dtype(dtype, self._domain)
        return MultiField({key: op.draw_sample(from_inverse, dtype[key])
                           for key, op in self._operators.items()})

    def _combine_chain(self, op):
        """Return the blockwise product ``self * op``.

        ``op`` must expose an ``_operators`` mapping containing every key of
        this operator (a missing key raises KeyError). Presumably ``op`` is
        another BlockDiagonalOperator on the same domain -- TODO confirm
        against the caller.
        """
        res = {key: self._operators[key]*op._operators[key]
               for key in self._operators}
        return BlockDiagonalOperator(res)

    def _combine_sum(self, op, selfneg, opneg):
        """Return the blockwise (optionally negated) sum of self and ``op``.

        ``selfneg`` and ``opneg`` are forwarded to ``SumOperator.make`` for
        every block. ``op`` must expose an ``_operators`` mapping with the
        same keys as this operator.
        """
        # Local import, as in the original.
        from ..operators.sum_operator import SumOperator
        res = {}
        for key in self._operators:
            res[key] = SumOperator.make([self._operators[key],
                                         op._operators[key]],
                                        [selfneg, opneg])
        return BlockDiagonalOperator(res)