Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
aea0988c
Commit
aea0988c
authored
Aug 09, 2021
by
Philipp Arras
Browse files
Merge JaxLinearOperators
parent
d8af23d6
Pipeline
#107326
failed with stages
in 9 minutes and 41 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/operators/jax_operator.py
View file @
aea0988c
...
...
@@ -65,7 +65,7 @@ class JaxOperator(Operator):
if
is_linearization
(
x
):
res
,
bwd
=
self
.
_vjp
(
x
.
val
.
val
)
fwd
=
lambda
y
:
self
.
_fwd
(
x
.
val
.
val
,
y
)
jac
=
_
JaxLinearOperator
(
self
.
_domain
,
self
.
_target
,
fwd
,
bwd
)
jac
=
JaxLinearOperator
(
self
.
_domain
,
self
.
_target
,
fwd
,
func_T
=
bwd
)
return
x
.
new
(
makeField
(
self
.
_target
,
_jax2np
(
res
)),
jac
)
res
=
_jax2np
(
self
.
_func
(
x
.
val
))
if
isinstance
(
res
,
dict
):
...
...
@@ -100,7 +100,7 @@ class JaxOperator(Operator):
return
None
,
JaxOperator
(
dom
,
self
.
_target
,
func2
)
def
JaxLinearOperator
(
domain
,
target
,
func
,
domain_dtype
):
class
JaxLinearOperator
(
LinearOperator
):
"""Wrap a jax function as nifty linear operator.
Parameters
...
...
@@ -117,24 +117,50 @@ def JaxLinearOperator(domain, target, func, domain_dtype):
`MultiDomain`, `func` takes a `dict` as argument and like-wise for the
target.
func_T : callable
The jax function that implements the transposed action of the operator.
If None, jax computes the adjoint. Note that this is *not* the adjoint
action. Default: None.
domain_dtype:
Dtype of the domain. If `domain` is a `MultiDomain`, `domain_dtype` is
supposed to be a dictionary.
Needs to be set if `func_transposed` is None. Otherwise it does not have
an effect. Dtype of the domain. If `domain` is a `MultiDomain`,
`domain_dtype` is supposed to be a dictionary. Default: None.
Note
----
It is the user's responsibility that func is actually a linear function. The
user can double check this with the help of `nifty8.extra.check_linear_operator`.
user can double check this with the help of
`nifty8.extra.check_linear_operator`.
"""
from
..domain_tuple
import
DomainTuple
from
..sugar
import
makeDomain
domain
=
makeDomain
(
domain
)
if
isinstance
(
domain
,
DomainTuple
):
inp
=
np
.
empty
(
domain
.
shape
,
domain_dtype
)
else
:
inp
=
{
kk
:
np
.
empty
(
domain
[
kk
].
shape
,
domain_dtype
[
kk
])
for
kk
in
domain
.
keys
()}
func_transposed
=
jax
.
jit
(
jax
.
linear_transpose
(
func
,
inp
))
return
_JaxLinearOperator
(
domain
,
target
,
func
,
func_transposed
)
def
__init__
(
self
,
domain
,
target
,
func
,
domain_dtype
=
None
,
func_T
=
None
):
from
..domain_tuple
import
DomainTuple
from
..sugar
import
makeDomain
domain
=
makeDomain
(
domain
)
if
domain_dtype
is
not
None
and
func_T
is
None
:
if
isinstance
(
domain
,
DomainTuple
):
inp
=
np
.
empty
(
domain
.
shape
,
domain_dtype
)
else
:
inp
=
{
kk
:
np
.
empty
(
domain
[
kk
].
shape
,
domain_dtype
[
kk
])
for
kk
in
domain
.
keys
()}
func_T
=
jax
.
jit
(
jax
.
linear_transpose
(
func
,
inp
))
elif
domain_dtype
is
None
and
func_T
is
not
None
:
pass
else
:
raise
ValueError
(
"Either domain_dtype or func_T have to be not None."
)
self
.
_domain
=
makeDomain
(
domain
)
self
.
_target
=
makeDomain
(
target
)
self
.
_func
=
func
self
.
_func_T
=
func_T
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
from
..sugar
import
makeField
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
fx
=
self
.
_func
(
x
.
val
)
return
makeField
(
self
.
_domain
,
_jax2np
(
fx
))
fx
=
self
.
_func_T
(
x
.
conjugate
().
val
)[
0
]
return
makeField
(
self
.
_target
,
_jax2np
(
fx
)).
conjugate
()
class
JaxLikelihoodEnergyOperator
(
LikelihoodEnergyOperator
):
...
...
@@ -198,22 +224,3 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
else
:
dt
=
self
.
_dt
return
None
,
JaxLikelihoodEnergyOperator
(
dom
,
func2
,
trafo
,
dt
)
class
_JaxLinearOperator
(
LinearOperator
):
def
__init__
(
self
,
domain
,
target
,
func
,
func_transposed
):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
domain
)
self
.
_target
=
makeDomain
(
target
)
self
.
_func
=
func
self
.
_func_transposed
=
func_transposed
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
from
..sugar
import
makeField
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
fx
=
self
.
_func
(
x
.
val
)
return
makeField
(
self
.
_tgt
(
mode
),
_jax2np
(
fx
))
fx
=
self
.
_func_transposed
(
x
.
conjugate
().
val
)[
0
]
return
makeField
(
self
.
_tgt
(
mode
),
_jax2np
(
fx
)).
conjugate
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment