jax_operator.py 8.84 KB
Newer Older
Philipp Arras's avatar
Philipp Arras committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras

Philipp Arras's avatar
Philipp Arras committed
17
18
from types import SimpleNamespace

Philipp Arras's avatar
Philipp Arras committed
19
20
import numpy as np

Philipp Arras's avatar
Philipp Arras committed
21
22
23
from .energy_operators import LikelihoodEnergyOperator
from .linear_operator import LinearOperator
from .operator import Operator
Philipp Arras's avatar
Philipp Arras committed
24
25
26
27

try:
    import jax
except ImportError:
    # jax is optional: without it this module exports nothing.
    __all__ = []
else:
    # Importing this module switches jax to 64-bit dtypes globally
    # (module-level side effect). Presumably this is so that results match
    # NumPy's float64 defaults -- confirm.
    jax.config.update("jax_enable_x64", True)
    __all__ = ["JaxOperator", "JaxLikelihoodEnergyOperator", "JaxLinearOperator"]

def _jax2np(obj):
    """Convert a jax result into NumPy data.

    A ``dict`` is converted entry by entry (one level deep) into a new dict
    of ``np.ndarray``; any other object is passed to ``np.array`` directly.
    """
    if not isinstance(obj, dict):
        return np.array(obj)
    converted = {}
    for key, value in obj.items():
        converted[key] = np.array(value)
    return converted


class JaxOperator(Operator):
    """Wrap a jax function as nifty operator.

    Parameters
    ----------
    domain : DomainTuple or MultiDomain
        Domain of the operator.

    target : DomainTuple or MultiDomain
        Target of the operator.

    func : callable
        The jax function that is evaluated by the operator. It has to be
        implemented in terms of `jax.numpy` calls. If `domain` is a
        `MultiDomain`, `func` takes a `dict` as argument and like-wise for the
        target.

    Notes
    -----
    The output of `func` is checked against `target` (dict vs. array, keys
    and shapes) both when applied to a field and when applied to a
    linearization.
    """
    def __init__(self, domain, target, func):
        from ..sugar import makeDomain
        self._domain = makeDomain(domain)
        self._target = makeDomain(target)
        self._func = jax.jit(func)
        # Returns (func(x), pullback). The pullback implements the transposed
        # Jacobian and is used as `func_T` of the Jacobian operator.
        self._vjp = jax.jit(lambda x: jax.vjp(func, x))
        # Jacobian-vector product: derivative of func at x in direction y.
        self._fwd = jax.jit(lambda x, y: jax.jvp(self._func, (x,), (y,))[1])

    def apply(self, x):
        from ..sugar import is_linearization, makeField
        self._check_input(x)
        if is_linearization(x):
            xval = x.val.val
            res, bwd = self._vjp(xval)
            res = _jax2np(res)
            self._check_output(res)
            fwd = lambda y: self._fwd(xval, y)
            jac = JaxLinearOperator(self._domain, self._target, fwd, func_T=bwd)
            return x.new(makeField(self._target, res), jac)
        res = _jax2np(self._func(x.val))
        self._check_output(res)
        return makeField(self._target, res)

    def _check_output(self, res):
        """Check that `res` (output of `func`, converted to NumPy) fits the
        target of the operator.

        Raises
        ------
        TypeError
            If `res` is a dict but the target is not a `MultiDomain`, or if
            `res` is not a dict but the target is a `MultiDomain`.
        ValueError
            If the keys or shapes of `res` do not match the target.
        """
        from ..multi_domain import MultiDomain
        if isinstance(res, dict):
            if not isinstance(self._target, MultiDomain):
                raise TypeError(("Jax function returns a dictionary although the "
                                 "target of the operator is a DomainTuple."))
            if set(res.keys()) != set(self._target.keys()):
                raise ValueError(("Keys do not match:\n"
                                  f"Target keys: {self._target.keys()}\n"
                                  f"Jax function returns: {res.keys()}"))
            for kk in res.keys():
                self._check_shape(self._target[kk].shape, res[kk].shape)
        else:
            if isinstance(self._target, MultiDomain):
                raise TypeError(("Jax function does not return a dictionary "
                                 "although the target of the operator is a "
                                 "MultiDomain."))
            self._check_shape(self._target.shape, res.shape)

    @staticmethod
    def _check_shape(shp_tgt, shp_jax):
        if shp_tgt != shp_jax:
            raise ValueError(("Output shapes do not match:\n"
                             f"Target shape is\t\t{shp_tgt}\n"
                             f"Jax function returns\t{shp_jax}"))

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        # Bake the constant entries into the function and keep only the
        # remaining (non-constant) keys in the domain.
        func2 = lambda x: self._func({**x, **c_inp.val})
        dom = {kk: vv for kk, vv in self._domain.items()
               if kk not in c_inp.keys()}
        return None, JaxOperator(dom, self._target, func2)


class JaxLinearOperator(LinearOperator):
    """Wrap a jax function as nifty linear operator.

    Parameters
    ----------
    domain : DomainTuple or MultiDomain
        Domain of the operator.

    target : DomainTuple or MultiDomain
        Target of the operator.

    func : callable
        The jax function that is evaluated by the operator. It has to be
        implemented in terms of `jax.numpy` calls. If `domain` is a
        `MultiDomain`, `func` takes a `dict` as argument and like-wise for the
        target.

    func_T : callable
        The jax function that implements the transposed action of the operator.
        If None, jax derives it from `func` via `jax.linear_transpose`. Note
        that this is *not* the adjoint action. Its return value is indexed
        with ``[0]``, i.e. it has to return a one-element tuple, like the
        functions returned by `jax.vjp` and `jax.linear_transpose`.
        Default: None.

    domain_dtype:
        Dtype of the domain. Needs to be set if `func_T` is None. If `domain`
        is a `MultiDomain`, `domain_dtype` is either a dictionary with one
        dtype per key or a single dtype used for all keys. Default: None.

    Exactly one of `func_T` and `domain_dtype` has to be given; otherwise a
    `ValueError` is raised.

    Note
    ----
    It is the user's responsibility that func is actually a linear function. The
    user can double check this with the help of
    `nifty8.extra.check_linear_operator`.
    """
    def __init__(self, domain, target, func, domain_dtype=None, func_T=None):
        from collections.abc import Mapping
        from ..domain_tuple import DomainTuple
        from ..sugar import makeDomain
        domain = makeDomain(domain)
        if (domain_dtype is None) == (func_T is None):
            # Both or neither given.
            raise ValueError("Either domain_dtype or func_T have to be not None.")
        if func_T is None:
            # jax.linear_transpose only reads `shape` and `dtype` of its
            # example input, so lightweight stand-ins are sufficient.
            if isinstance(domain, DomainTuple):
                inp = SimpleNamespace(shape=domain.shape, dtype=domain_dtype)
            else:
                def dtype_of(key):
                    if isinstance(domain_dtype, Mapping):
                        return domain_dtype[key]
                    return domain_dtype
                inp = {kk: SimpleNamespace(shape=domain[kk].shape,
                                           dtype=dtype_of(kk))
                       for kk in domain.keys()}
            func_T = jax.jit(jax.linear_transpose(func, inp))
        self._domain = domain
        self._target = makeDomain(target)
        self._func = func
        self._func_T = func_T
        self._capability = self.TIMES | self.ADJOINT_TIMES

    def apply(self, x, mode):
        from ..sugar import makeField
        self._check_input(x, mode)
        if mode == self.TIMES:
            return makeField(self._target, _jax2np(self._func(x.val)))
        # Adjoint action = conj(transpose(conj(x))); func_T returns a 1-tuple.
        fx = self._func_T(x.conjugate().val)[0]
        return makeField(self._domain, _jax2np(fx)).conjugate()


class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
    """Wrap a jax function as nifty likelihood energy operator.

    Parameters
    ----------
    domain : DomainTuple or MultiDomain
        Domain of the operator.

    func : callable
        The jax function that is evaluated by the operator. It has to be
        implemented in terms of `jax.numpy` calls. If `domain` is a
        `MultiDomain`, `func` takes a `dict` as argument. It needs to map to
        a scalar.

    transformation : Operator, optional
        Coordinate transformation to Euclidean space.

    sampling_dtype : dtype or dict of dtype, optional
        The dtype that shall be used for drawing samples from the metric of the
        likelihood. If it is a dictionary, it is restricted to the remaining
        keys when the operator is simplified for constant input.
    """
    def __init__(self, domain, func, transformation=None, sampling_dtype=None):
        from ..sugar import makeDomain
        self._domain = makeDomain(domain)
        self._func = jax.jit(func)
        self._val_and_grad = jax.jit(jax.value_and_grad(func))
        self._dt = sampling_dtype
        self._trafo = transformation

    def get_transformation(self):
        """Return ``(sampling_dtype, transformation)``.

        Raises
        ------
        RuntimeError
            If the operator was instantiated without `transformation`.
        """
        if self._trafo is None:
            # Instances have no `__name__`; use the class name instead.
            s = type(self).__name__ + " was instantiated without `transformation`"
            raise RuntimeError(s)
        return self._dt, self._trafo

    def apply(self, x):
        from ..sugar import is_linearization, makeField
        from .simple_linear_operators import VdotOperator
        self._check_input(x)
        if not is_linearization(x):
            return makeField(self._target, _jax2np(self._func(x.val)))
        res, grad = self._val_and_grad(x.val.val)
        # The Jacobian of a scalar energy is the scalar product with its
        # gradient.
        jac = VdotOperator(makeField(self._domain, _jax2np(grad)))
        res = x.new(makeField(self._target, _jax2np(res)), jac)
        if not x.want_metric:
            return res
        return res.add_metric(self.get_metric_at(x.val))

    def _simplify_for_constant_input_nontrivial(self, c_inp):
        # Bake the constant entries into the function and keep only the
        # remaining (non-constant) keys in the domain.
        func2 = lambda x: self._func({**x, **c_inp.val})
        dom = {kk: vv for kk, vv in self._domain.items()
               if kk not in c_inp.keys()}
        # An operator without transformation stays without one.
        trafo = None
        if self._trafo is not None:
            _, trafo = self._trafo.simplify_for_constant_input(c_inp)
        if isinstance(self._dt, dict):
            dt = {kk: self._dt[kk] for kk in dom.keys()}
        else:
            dt = self._dt
        return None, JaxLikelihoodEnergyOperator(dom, func2, trafo, dt)