Commit b7246142 authored by Lukas Platz's avatar Lukas Platz
Browse files

fix typo

parent e5a47632
Pipeline #106239 passed with stages
in 21 minutes and 1 second
%% Cell type:markdown id:indonesian-dayton tags:
# Custom nonlinearities
%% Cell type:code id:duplicate-fitting tags:
``` python
import numpy as np
import nifty8 as ift
```
%% Cell type:markdown id:democratic-lancaster tags:
In NIFTy, users can add hand-crafted point-wise nonlinearities that are then available for `Field`, `MultiField`, `Linearization` and `Operator`. This guide illustrates how this is done.
Suppose that we would like to use the point-wise function f(x) = x*exp(x) in an operator chain. This function is called "myptw" in the following. We introduce this function to NIFTy by implementing two functions.
First, one that takes a `numpy.ndarray` as an input, applies the point-wise mapping and returns the result as a `numpy.ndarray` of the same shape. Second, a function that takes a `numpy.ndarray` as an input and returns two `numpy.ndarray`s: the application of the nonlinearity (same as before) and the derivative.
%% Cell type:code id:modern-spouse tags:
``` python
def func(x):
return x*np.exp(x)
def func_and_derv(x):
expx = np.exp(x)
return x*expx, (1+x)*expx
```
%% Cell type:markdown id:shared-deficit tags:
These two functions are then added to the NIFTy-internal dictionary that contains all implemented point-wise nonlinearities.
%% Cell type:code id:published-start tags:
``` python
ift.pointwise.ptw_dict["myptw"] = func, func_and_derv
```
%% Cell type:markdown id:living-surrey tags:
This allows us to apply this non-linearity on `Field`s, ...
%% Cell type:code id:incident-biotechnology tags:
``` python
dom = ift.UnstructuredDomain(10)
fld = ift.from_random(dom)
fld = ift.full(dom, 2.)
a = fld.ptw("myptw")
b = ift.makeField(dom, func(fld.val))
ift.extra.assert_allclose(a, b)
```
%% Cell type:markdown id:palestinian-librarian tags:
`MultiField`s, ...
%% Cell type:code id:naval-nightmare tags:
``` python
mdom = ift.makeDomain({"bar": ift.UnstructuredDomain(10)})
mfld = ift.from_random(mdom)
a = mfld.ptw("myptw")
b = ift.makeField(mdom, {"bar": func(mfld["bar"].val)})
ift.extra.assert_allclose(a, b)
```
%% Cell type:markdown id:legendary-oriental tags:
`Linearization`s (including the Jacobian), ...
%% Cell type:code id:native-breeding tags:
``` python
lin = ift.Linearization.make_var(fld)
a = lin.ptw("myptw").val
b = ift.makeField(dom, func(fld.val))
ift.extra.assert_allclose(a, b)
```
%% Cell type:code id:crude-motorcycle tags:
``` python
op_a = lin.ptw("myptw").jac
op_b = ift.makeOp(ift.makeField(dom, func_and_derv(fld.val)[1]))
testing_vector = ift.from_random(dom)
ift.extra.assert_allclose(op_a(testing_vector),
op_b(testing_vector))
```
%% Cell type:markdown id:outdoor-juice tags:
and `Operator`s.
%% Cell type:code id:retained-closer tags:
``` python
op = ift.FieldAdapter(dom, "foo").ptw("myptw")
```
%% Cell type:markdown id:accessory-pepper tags:
Please remember to always check that the gradient has been implemented correctly by comparint it to an approximation to the gradient by finite differences.
Please remember to always check that the gradient has been implemented correctly by comparing it to an approximation to the gradient by finite differences.
%% Cell type:code id:close-bonus tags:
``` python
def check(func_name, eps=1e-7):
pos = ift.from_random(ift.UnstructuredDomain(10))
var0 = ift.Linearization.make_var(pos)
var1 = ift.Linearization.make_var(pos+eps)
df0 = (var1.ptw(func_name).val - var0.ptw(func_name).val)/eps
df1 = var0.ptw(func_name).jac(ift.full(lin.domain, 1.))
# rtol depends on how nonlinear the function is
ift.extra.assert_allclose(df0, df1, rtol=100*eps)
check("myptw")
```
%% Cell type:code id:comfortable-kitty tags:
``` python
```
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment