Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
a3805bbe
Commit
a3805bbe
authored
Jul 13, 2021
by
Philipp Arras
Browse files
JaxOperator: Support complex functions
parent
404027de
Changes
3
Hide whitespace changes
Inline
Side-by-side
src/extra.py
View file @
a3805bbe
...
@@ -364,7 +364,7 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
...
@@ -364,7 +364,7 @@ def _check_nontrivial_constant(op, loc, tol, ntries, only_r_differentiable,
oplin
=
op
(
lin
)
oplin
=
op
(
lin
)
myassert
(
oplin
.
jac
.
target
is
oplin0
.
jac
.
target
)
myassert
(
oplin
.
jac
.
target
is
oplin0
.
jac
.
target
)
rndinp
=
from_random
(
oplin
.
jac
.
target
)
rndinp
=
from_random
(
oplin
.
jac
.
target
,
dtype
=
oplin
.
val
.
dtype
)
assert_allclose
(
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
varloc
.
domain
),
assert_allclose
(
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
varloc
.
domain
),
oplin0
.
jac
.
adjoint
(
rndinp
),
1e-13
,
1e-13
)
oplin0
.
jac
.
adjoint
(
rndinp
),
1e-13
,
1e-13
)
foo
=
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
cstloc
.
domain
)
foo
=
oplin
.
jac
.
adjoint
(
rndinp
).
extract
(
cstloc
.
domain
)
...
...
src/operators/jax_operator.py
View file @
a3805bbe
...
@@ -165,12 +165,12 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
...
@@ -165,12 +165,12 @@ class JaxLikelihoodEnergyOperator(LikelihoodEnergyOperator):
class
_JaxJacobian
(
LinearOperator
):
class
_JaxJacobian
(
LinearOperator
):
def
__init__
(
self
,
domain
,
target
,
func
,
adj
func
):
def
__init__
(
self
,
domain
,
target
,
func
,
func
_transposed
):
from
..sugar
import
makeDomain
from
..sugar
import
makeDomain
self
.
_domain
=
makeDomain
(
domain
)
self
.
_domain
=
makeDomain
(
domain
)
self
.
_target
=
makeDomain
(
target
)
self
.
_target
=
makeDomain
(
target
)
self
.
_func
=
func
self
.
_func
=
func
self
.
_
adj
func
=
adjfunc
self
.
_func
_transposed
=
func_transposed
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
self
.
_capability
=
self
.
TIMES
|
self
.
ADJOINT_TIMES
def
apply
(
self
,
x
,
mode
):
def
apply
(
self
,
x
,
mode
):
...
@@ -178,6 +178,6 @@ class _JaxJacobian(LinearOperator):
...
@@ -178,6 +178,6 @@ class _JaxJacobian(LinearOperator):
self
.
_check_input
(
x
,
mode
)
self
.
_check_input
(
x
,
mode
)
if
mode
==
self
.
TIMES
:
if
mode
==
self
.
TIMES
:
fx
=
self
.
_func
(
x
.
val
)
fx
=
self
.
_func
(
x
.
val
)
else
:
return
makeField
(
self
.
_tgt
(
mode
),
_jax2np
(
fx
))
fx
=
self
.
_
adj
func
(
x
.
val
)[
0
]
fx
=
self
.
_func
_transposed
(
x
.
conjugate
()
.
val
)[
0
]
return
makeField
(
self
.
_tgt
(
mode
),
_jax2np
(
fx
))
return
makeField
(
self
.
_tgt
(
mode
),
_jax2np
(
fx
))
.
conjugate
()
test/test_operators/test_jax.py
View file @
a3805bbe
...
@@ -104,3 +104,41 @@ def test_jax_errors():
...
@@ -104,3 +104,41 @@ def test_jax_errors():
op
=
ift
.
JaxOperator
(
dom
,
mdom
,
lambda
x
:
{
"a"
:
x
[
0
]})
op
=
ift
.
JaxOperator
(
dom
,
mdom
,
lambda
x
:
{
"a"
:
x
[
0
]})
with
pytest
.
raises
(
ValueError
):
with
pytest
.
raises
(
ValueError
):
op
(
fld
)
op
(
fld
)
def
test_jax_complex
():
dom
=
ift
.
UnstructuredDomain
(
1
)
a
=
ift
.
ducktape
(
dom
,
None
,
"a"
)
b
=
ift
.
ducktape
(
dom
,
None
,
"b"
)
op
=
a
.
real
+
1j
*
b
.
real
op1
=
ift
.
JaxOperator
(
op
.
domain
,
op
.
target
,
lambda
x
:
x
[
"a"
]
+
1j
*
x
[
"b"
])
_op_equal
(
op
,
op1
,
ift
.
from_random
(
op
.
domain
))
ift
.
extra
.
check_operator
(
op
,
ift
.
from_random
(
op
.
domain
),
ntries
=
10
)
ift
.
extra
.
check_operator
(
op1
,
ift
.
from_random
(
op
.
domain
),
ntries
=
10
)
op
=
op
.
imag
op1
=
op1
.
imag
_op_equal
(
op
,
op1
,
ift
.
from_random
(
op
.
domain
))
ift
.
extra
.
check_operator
(
op
,
ift
.
from_random
(
op
.
domain
),
ntries
=
10
)
ift
.
extra
.
check_operator
(
op1
,
ift
.
from_random
(
op
.
domain
),
ntries
=
10
)
lin
=
ift
.
Linearization
.
make_var
(
ift
.
from_random
(
op
.
domain
))
test_vec
=
ift
.
full
(
op
.
target
,
1.
)
grad
=
op
(
lin
).
jac
.
adjoint
(
test_vec
)
grad1
=
op1
(
lin
).
jac
.
adjoint
(
test_vec
)
ift
.
extra
.
assert_equal
(
grad
,
grad1
)
ift
.
extra
.
assert_equal
(
grad
,
ift
.
makeField
(
grad
.
domain
,
{
"a"
:
0.
,
"b"
:
1.
}))
def
_op_equal
(
op0
,
op1
,
loc
):
assert
op0
.
domain
is
op1
.
domain
assert
op0
.
target
is
op1
.
target
ift
.
extra
.
assert_allclose
(
op0
(
loc
),
op1
(
loc
))
lin
=
ift
.
Linearization
.
make_var
(
loc
)
res
=
op0
(
lin
)
res1
=
op1
(
lin
)
ift
.
extra
.
assert_allclose
(
res
.
val
,
res1
.
val
)
fld
=
ift
.
from_random
(
op0
.
domain
,
dtype
=
loc
.
dtype
)
ift
.
extra
.
assert_allclose
(
res
.
jac
(
fld
),
res1
.
jac
(
fld
))
fld
=
ift
.
from_random
(
op0
.
target
,
dtype
=
res
.
val
.
dtype
)
ift
.
extra
.
assert_allclose
(
res
.
jac
.
adjoint
(
fld
),
res1
.
jac
.
adjoint
(
fld
))
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment