Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
febe7090
Commit
febe7090
authored
Jun 10, 2021
by
Philipp Arras
Browse files
Add explanation how to add nonlinearities to NIFTy
parent
2236cf7b
Pipeline
#103346
passed with stages
in 14 minutes and 10 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
.gitlab-ci.yml
View file @
febe7090
...
...
@@ -147,3 +147,8 @@ run_visual_vi:
stage
:
demo_runs
script
:
-
python3 demos/variational_inference_visualized.py
run_nonlinearity_guide
:
stage
:
demo_runs
script
:
-
python3 demos/custom_nonlinearities.py
demos/custom_nonlinearities.py
0 → 100644
View file @
febe7090
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
import
nifty7
as
ift
import
numpy
as
np
# In NIFTy, users can add hand-crafted point-wise nonlinearities that are then
# available for `Field`, `MultiField`, `Linearization` and `Operator`. This
# guide shows an example how this is done.
# Suppose, we would like to use the function f(x) = x*exp(x) point-wise in an
# operator chain. This function is called "myptw" in the following. We
# introduce this function to NIFTy by implementing two functions.
# First, one that takes a `numpy.ndarray` as an input, applies the point-wise
# mapping and returns the result as a `numpy.ndarray` (of the same shape).
# Second, a function that takes an `numpy.ndarray` as an input and returns two
# `numpy.ndarray`s: the application of the nonlinearity (same as before) and
# the derivative.
def
func
(
x
):
return
x
*
np
.
exp
(
x
)
def
func_and_derv
(
x
):
expx
=
np
.
exp
(
x
)
return
x
*
expx
,
(
1
+
x
)
*
expx
# These two functions are then added to the NIFTy-internal dictionary that
# contains all implemented point-wise nonlinearities.
ift
.
pointwise
.
ptw_dict
[
"myptw"
]
=
func
,
func_and_derv
# This allows us to apply this non-linearity on `Field`s, ...
dom
=
ift
.
UnstructuredDomain
(
10
)
fld
=
ift
.
from_random
(
dom
)
fld
=
ift
.
full
(
dom
,
2.
)
a
=
fld
.
ptw
(
"myptw"
)
b
=
ift
.
makeField
(
dom
,
func
(
fld
.
val
))
ift
.
extra
.
assert_allclose
(
a
,
b
)
# `MultiField`s, ...
mdom
=
ift
.
makeDomain
({
"bar"
:
ift
.
UnstructuredDomain
(
10
)})
mfld
=
ift
.
from_random
(
mdom
)
a
=
mfld
.
ptw
(
"myptw"
)
b
=
ift
.
makeField
(
mdom
,
{
"bar"
:
func
(
mfld
[
"bar"
].
val
)})
ift
.
extra
.
assert_allclose
(
a
,
b
)
# Linearizations (including the Jacobian), ...
# (Value)
lin
=
ift
.
Linearization
.
make_var
(
fld
)
a
=
lin
.
ptw
(
"myptw"
).
val
b
=
ift
.
makeField
(
dom
,
func
(
fld
.
val
))
ift
.
extra
.
assert_allclose
(
a
,
b
)
# (Jacobian)
op_a
=
lin
.
ptw
(
"myptw"
).
jac
op_b
=
ift
.
makeOp
(
ift
.
makeField
(
dom
,
func_and_derv
(
fld
.
val
)[
1
]))
testing_vector
=
ift
.
from_random
(
dom
)
ift
.
extra
.
assert_allclose
(
op_a
(
testing_vector
),
op_b
(
testing_vector
))
# and `Operator`s.
op
=
ift
.
FieldAdapter
(
dom
,
"foo"
).
ptw
(
"myptw"
)
# We check that the gradient has been implemented correctly by comparing it to
# an approximation to the gradient by finite differences.
def
check
(
func_name
,
eps
=
1e-7
):
pos
=
ift
.
from_random
(
ift
.
UnstructuredDomain
(
10
))
var0
=
ift
.
Linearization
.
make_var
(
pos
)
var1
=
ift
.
Linearization
.
make_var
(
pos
+
eps
)
df0
=
(
var1
.
ptw
(
func_name
).
val
-
var0
.
ptw
(
func_name
).
val
)
/
eps
df1
=
var0
.
ptw
(
func_name
).
jac
(
ift
.
full
(
lin
.
domain
,
1.
))
# rtol depends on how nonlinear the function is
ift
.
extra
.
assert_allclose
(
df0
,
df1
,
rtol
=
100
*
eps
)
check
(
"myptw"
)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment