block_diagonal_operator.py 2.03 KB
Newer Older
Martin Reinecke's avatar
Martin Reinecke committed
1
2
3
4
5
6
7
import numpy as np
from ..operators.endomorphic_operator import EndomorphicOperator
from .multi_domain import MultiDomain
from .multi_field import MultiField


class BlockDiagonalOperator(EndomorphicOperator):
    """Endomorphic operator acting independently on each block of a domain.

    One operator (or None) is stored per key of `domain`, in the order given
    by `domain.keys()`.  `apply` hands every component of the input to the
    operator stored for the same key.  A None entry produces a None result
    for that block, i.e. it is treated like a zero block throughout this
    class.
    """

    def __init__(self, domain, operators):
        """
        Parameters
        ----------
        domain : MultiDomain
            domain (and target) of the operator; its keys define the blocks
            and their order (presumably a MultiDomain -- only `.keys()` is
            used here)
        operators : dict
            dictionary with operators domain names as keys and
            LinearOperators as items; an item may be None

        Raises
        ------
        KeyError
            if `operators` lacks an entry for one of the domain keys
        """
        super(BlockDiagonalOperator, self).__init__()
        self._domain = domain
        # Store the blocks positionally, following the key order of the
        # domain, so they can later be zipped with the field components.
        self._ops = tuple(operators[key] for key in self.domain.keys())
        # The combined operator supports only those modes that every
        # non-empty block supports.
        self._cap = self._all_ops
        for op in self._ops:
            if op is not None:
                self._cap &= op.capability

    @property
    def domain(self):
        """The domain object passed to the constructor."""
        return self._domain

    @property
    def capability(self):
        """Bitwise AND of the capabilities of all non-None blocks."""
        return self._cap

    def apply(self, x, mode):
        """Apply every block to the matching component of `x`.

        Parameters
        ----------
        x : MultiField
            input; its `values()` are paired positionally with the blocks
        mode : int
            application mode, forwarded unchanged to every block

        Returns
        -------
        MultiField
            defined on this operator's domain; blocks whose operator is
            None yield None
        """
        self._check_input(x, mode)
        val = tuple(op.apply(v, mode=mode) if op is not None else None
                    for op, v in zip(self._ops, x.values()))
        return MultiField(self._domain, val)

    # NOTE: draw_sample is currently disabled.  If it is re-enabled, the
    # attribute holding the blocks is `self._ops`, not `self._op`.
#    def draw_sample(self, from_inverse=False, dtype=np.float64):
#        dtype = MultiField.build_dtype(dtype, self._domain)
#        val = tuple(op.draw_sample(from_inverse, dtype)
#                    for op in self._op)
#        return MultiField(self._domain, val)

    def _combine_chain(self, op):
        """Return the block-wise product ``self * op``.

        Parameters
        ----------
        op : BlockDiagonalOperator
            right-hand factor; must live on the very same domain object

        Returns
        -------
        BlockDiagonalOperator
            block `key` is ``self_block * op_block``, or None if either
            factor is None

        Raises
        ------
        ValueError
            if the two operators do not share the same domain object
        """
        if self._domain is not op._domain:
            raise ValueError("domain mismatch")
        # A None block behaves like zero in `apply`, so a product containing
        # one is itself an empty (None) block instead of a TypeError.
        res = {key: None if v1 is None or v2 is None else v1*v2
               for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops)}
        return BlockDiagonalOperator(self._domain, res)

    def _combine_sum(self, op, selfneg, opneg):
        """Return the block-wise signed sum of `self` and `op`.

        Parameters
        ----------
        op : BlockDiagonalOperator
            second summand; must live on the very same domain object
        selfneg : bool
            negation flag passed to SumOperator.make for the blocks of
            `self`
        opneg : bool
            negation flag passed to SumOperator.make for the blocks of `op`

        Returns
        -------
        BlockDiagonalOperator
            block `key` is the SumOperator built from the non-None blocks
            of both operands, or None if both are None

        Raises
        ------
        ValueError
            if the two operators do not share the same domain object
        """
        # Imported locally, as in the original code (presumably to avoid a
        # circular import at module load time).
        from ..operators.sum_operator import SumOperator
        if self._domain is not op._domain:
            raise ValueError("domain mismatch")
        res = {}
        for key, v1, v2 in zip(self._domain.keys(), self._ops, op._ops):
            # Drop None blocks (they contribute nothing to the sum) but keep
            # each remaining block paired with its own sign flag.
            terms = [(blk, neg)
                     for blk, neg in ((v1, selfneg), (v2, opneg))
                     if blk is not None]
            if terms:
                res[key] = SumOperator.make([blk for blk, _ in terms],
                                            [neg for _, neg in terms])
            else:
                res[key] = None
        return BlockDiagonalOperator(self._domain, res)