Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
241ce41f
Commit
241ce41f
authored
Jun 30, 2021
by
Philipp Arras
Browse files
Add jax likelihood operator
parent
441fffe6
Changes
2
Hide whitespace changes
Inline
Side-by-side
src/operators/jax_operator.py
View file @
241ce41f
...
...
@@ -16,13 +16,15 @@
import
numpy
as
np
from
.operator
import
Operator
from
.energy_operators
import
EnergyOperator
,
LikelihoodEnergyOperator
from
.linear_operator
import
LinearOperator
from
.endomorphic_operator
import
EndomorphicOperator
try
:
import
jax
jax
.
config
.
update
(
"jax_enable_x64"
,
True
)
__all__
=
[
"JaxOperator"
]
__all__
=
[
"JaxOperator"
,
"JaxLikelihoodEnergyOperator"
]
except
ImportError
:
__all__
=
[]
...
...
@@ -74,6 +76,37 @@ class JaxOperator(Operator):
return
None
,
JaxOperator
(
dom
,
self
.
_target
,
func2
)
class
JaxLikelihoodEnergyOperator
(
LikelihoodEnergyOperator
):
def
__init__
(
self
,
domain
,
func
,
transformation
=
None
,
sampling_dtype
=
None
):
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
domain
)
self
.
_func
=
jax
.
jit
(
func
)
self
.
_grad
=
jax
.
jit
(
jax
.
grad
(
func
))
self
.
_dt
=
sampling_dtype
self
.
_trafo
=
transformation
def
get_transformation
(
self
):
if
self
.
_trafo
is
None
:
s
=
self
.
__name__
+
" was instantiated without `transformation`"
raise
RuntimeError
(
s
)
return
self
.
_dt
,
self
.
_trafo
def
apply
(
self
,
x
):
from
..sugar
import
is_linearization
,
makeField
from
.simple_linear_operators
import
VdotOperator
from
..linearization
import
Linearization
self
.
_check_input
(
x
)
lin
=
is_linearization
(
x
)
val
=
x
.
val
.
val
if
lin
else
x
.
val
res
=
makeField
(
self
.
_target
,
_jax2np
(
self
.
_func
(
val
)))
if
not
lin
:
return
res
jac
=
VdotOperator
(
makeField
(
self
.
_domain
,
_jax2np
(
self
.
_grad
(
val
))))
res
=
Linearization
(
res
,
jac
)
if
not
x
.
want_metric
:
return
res
return
res
.
add_metric
(
self
.
get_metric_at
(
x
.
val
))
class
_JaxJacobian
(
LinearOperator
):
def
__init__
(
self
,
domain
,
target
,
func
,
adjfunc
):
...
...
test/test_operators/test_jax.py
View file @
241ce41f
...
...
@@ -62,6 +62,26 @@ def test_mf_jax():
ift
.
extra
.
check_operator
(
op
,
loc
)
def
test_jax_energy
():
if
_skip
:
pytest
.
skip
()
dom
=
ift
.
UnstructuredDomain
((
10
,
2
))
e0
=
ift
.
GaussianEnergy
(
domain
=
dom
)
def
func
(
x
):
return
0.5
*
jnp
.
vdot
(
x
,
x
)
e
=
ift
.
JaxLikelihoodEnergyOperator
(
dom
,
func
,
transformation
=
ift
.
ScalingOperator
(
dom
,
1.
))
for
wm
in
[
False
,
True
]:
pos
=
ift
.
from_random
(
e
.
domain
)
lin
=
ift
.
Linearization
.
make_var
(
pos
,
wm
)
ift
.
extra
.
assert_allclose
(
e0
(
pos
),
e
(
pos
))
ift
.
extra
.
assert_allclose
(
e0
(
lin
).
val
,
e
(
lin
).
val
)
ift
.
extra
.
assert_allclose
(
e0
(
lin
).
gradient
,
e
(
lin
).
gradient
)
if
not
wm
:
continue
pos1
=
ift
.
from_random
(
e
.
domain
)
ift
.
extra
.
assert_allclose
(
e0
(
lin
).
metric
(
pos1
),
e
(
lin
).
metric
(
pos1
))
def
test_cf
():
dom
=
ift
.
RGSpace
([
13
,
14
],
distances
=
(
0.89
,
0.9
))
args
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment