simplify_for_const.py 4.6 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.

from ..multi_domain import MultiDomain
Philipp Arras's avatar
Philipp Arras committed
19
from .block_diagonal_operator import BlockDiagonalOperator
Philipp Arras's avatar
Philipp Arras committed
20
21
from .energy_operators import EnergyOperator
from .operator import Operator
Philipp Arras's avatar
Philipp Arras committed
22
from .scaling_operator import ScalingOperator
Philipp Arras's avatar
Philipp Arras committed
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from .simple_linear_operators import NullOperator


class ConstCollector(object):
    """Collect the constant part of a sum or product of operator outputs.

    Each operand is reported via :meth:`add` or :meth:`mult` together with
    its full target domain and the part of its output that is constant
    (``None`` if no part is constant). A key stays in :attr:`constfield`
    only while no operand reported so far has marked it as non-constant.
    """

    def __init__(self):
        # Combined constant values (a MultiField), or None as long as no
        # operand with a constant part has been reported.
        self._const = None
        # Keys that at least one reported operand has marked non-constant.
        self._nc = set()

    def _drop_nonconstant(self):
        # Remove every key in self._nc from the collected field. Needed when
        # an operand without any constant part arrives after one that did
        # report constants: its keys must no longer appear in constfield.
        if self._const is None:
            return
        from ..multi_field import MultiField
        self._const = MultiField.from_dict(
            {key: self._const[key]
             for key in self._const.keys() if key not in self._nc})

    def mult(self, const, fulldom):
        """Register one factor of a product.

        Parameters
        ----------
        const : MultiField or None
            Constant part of the factor's output, or None if nothing is
            constant.
        fulldom : MultiDomain
            Full target of the factor. Its keys that are missing from
            ``const`` are marked non-constant.

        Notes
        -----
        Constant values are multiplied key by key. This assumes each
        constant key of ``const`` is also present in the previously
        collected field (i.e. all factors share the same target keys);
        otherwise the lookup into the collected field fails.
        """
        if const is None:
            self._nc |= set(fulldom.keys())
            self._drop_nonconstant()
            return
        from ..multi_field import MultiField
        self._nc |= set(fulldom.keys()) - set(const.keys())
        if self._const is None:
            vals = {key: const[key]
                    for key in const.keys() if key not in self._nc}
        else:
            vals = {key: self._const[key]*const[key]
                    for key in const.keys() if key not in self._nc}
        self._const = MultiField.from_dict(vals)

    def add(self, const, fulldom):
        """Register one summand of a sum.

        Parameters
        ----------
        const : MultiField or None
            Constant part of the summand's output, or None if nothing is
            constant.
        fulldom : MultiDomain
            Full target of the summand. Its keys that are missing from
            ``const`` are marked non-constant.

        Constant values are combined with ``MultiField.unite``, then every
        key known to be non-constant is dropped.
        """
        if const is None:
            self._nc |= set(fulldom.keys())
            self._drop_nonconstant()
            return
        from ..multi_field import MultiField
        self._nc |= set(fulldom.keys()) - set(const.keys())
        merged = const if self._const is None else self._const.unite(const)
        self._const = MultiField.from_dict(
            {key: merged[key]
             for key in merged.keys() if key not in self._nc})

    @property
    def constfield(self):
        """Constant values collected so far (MultiField), or None."""
        return self._const


class ConstantOperator(Operator):
    """Operator that ignores its input and always returns a fixed value.

    Parameters
    ----------
    dom : domain-like
        Input domain; normalized with ``makeDomain``.
    output : Field or MultiField
        Value returned on every call; its domain becomes the target.
    """

    def __init__(self, dom, output):
        from ..sugar import makeDomain
        self._domain = makeDomain(dom)
        self._target = output.domain
        self._output = output

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            return self._output
        # The result does not depend on x, so the Jacobian is a
        # NullOperator from the input domain to the target.
        zero_jac = NullOperator(self._domain, self._target)
        return x.new(self._output, zero_jac)

    def __repr__(self):
        def _describe(space):
            # Only MultiDomains have named keys; anything else prints '()'.
            return space.keys() if isinstance(space, MultiDomain) else '()'
        dom = _describe(self.domain)
        tgt = _describe(self.target)
        return f'{tgt} <- ConstantOperator <- {dom}'


Philipp Arras's avatar
Philipp Arras committed
87
88
89
90
91
class SlowPartialConstantOperator(Operator):
    """Identity on a MultiDomain whose Jacobian zeroes out selected keys.

    The value passes through unchanged. In the Jacobian, each key in
    ``constant_keys`` gets a ScalingOperator with factor 0 and every other
    key a ScalingOperator with factor 1.

    Parameters
    ----------
    domain : MultiDomain
        Input (and target) domain.
    constant_keys : collection of keys
        Keys to treat as constant. Must be non-empty. Keys absent from
        ``domain`` are ignored, unless ``constant_keys`` is a strict
        superset of the domain keys.

    Raises
    ------
    TypeError
        If ``domain`` is not a MultiDomain.
    ValueError
        If ``constant_keys`` is empty or a strict superset of the keys of
        ``domain``.
    """

    def __init__(self, domain, constant_keys):
        from ..sugar import makeDomain
        if not isinstance(domain, MultiDomain):
            raise TypeError
        requested = set(constant_keys)
        available = set(domain.keys())
        if requested > available or len(constant_keys) == 0:
            raise ValueError
        # Keys outside the domain are silently dropped here.
        self._keys = requested & available
        self._domain = self._target = makeDomain(domain)

    def apply(self, x):
        self._check_input(x)
        if x.jac is None:
            # No linearization requested: value passes through untouched.
            return x
        blocks = {}
        for key, dom in self._domain.items():
            blocks[key] = ScalingOperator(dom, 0 if key in self._keys else 1)
        return x.prepend_jac(BlockDiagonalOperator(x.jac.domain, blocks))

    def __repr__(self):
        return 'SlowPartialConstantOperator ({})'.format(self._keys)
Philipp Arras's avatar
Philipp Arras committed
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127


class ConstantEnergyOperator(EnergyOperator):
    """Energy operator that returns a fixed value independent of its input.

    Parameters
    ----------
    dom : domain-like
        Input domain; normalized with ``makeDomain``.
    output : Field
        Energy value returned on every call. Its domain must be the very
        same object as ``self.target`` (checked by identity).

    Raises
    ------
    TypeError
        If ``output.domain is not self.target``.
    """

    def __init__(self, dom, output):
        from ..sugar import makeDomain
        self._domain = makeDomain(dom)
        # NOTE(review): self.target / self._target are not set here;
        # presumably provided by EnergyOperator — confirm.
        if self.target is not output.domain:
            raise TypeError
        self._output = output

    def apply(self, x):
        self._check_input(x)
        if x.jac is not None:
            val = self._output
            # The output does not depend on x: Jacobian (and metric, if
            # requested) are NullOperators.
            jac = NullOperator(self._domain, self._target)
            met = NullOperator(self._domain, self._domain) if x.want_metric else None
            return x.new(val, jac, met)
        return self._output

    def __repr__(self):
        # makeDomain may yield a non-MultiDomain, which has no keys();
        # render such domains as '()' instead of raising AttributeError.
        dom = self.domain.keys() if isinstance(self.domain, MultiDomain) else '()'
        return 'ConstantEnergyOperator <- {}'.format(dom)