Commit b7246142 by Lukas Platz

fix typo

parent e5a47632
Pipeline #106239 passed with stages
in 21 minutes and 1 second
 %% Cell type:markdown id:indonesian-dayton tags: # Custom nonlinearities %% Cell type:code id:duplicate-fitting tags: ``` python import numpy as np import nifty8 as ift ``` %% Cell type:markdown id:democratic-lancaster tags: In NIFTy, users can add hand-crafted point-wise nonlinearities that are then available for `Field`, `MultiField`, `Linearization` and `Operator`. This guide illustrates how this is done. Suppose that we would like to use the point-wise function f(x) = x*exp(x) in an operator chain. This function is called "myptw" in the following. We introduce this function to NIFTy by implementing two functions. First, one that takes a `numpy.ndarray` as an input, applies the point-wise mapping and returns the result as a `numpy.ndarray` of the same shape. Second, a function that takes a `numpy.ndarray` as an input and returns two `numpy.ndarray`s: the application of the nonlinearity (same as before) and the derivative. %% Cell type:code id:modern-spouse tags: ``` python def func(x): return x*np.exp(x) def func_and_derv(x): expx = np.exp(x) return x*expx, (1+x)*expx ``` %% Cell type:markdown id:shared-deficit tags: These two functions are then added to the NIFTy-internal dictionary that contains all implemented point-wise nonlinearities. %% Cell type:code id:published-start tags: ``` python ift.pointwise.ptw_dict["myptw"] = func, func_and_derv ``` %% Cell type:markdown id:living-surrey tags: This allows us to apply this non-linearity on `Field`s, ... %% Cell type:code id:incident-biotechnology tags: ``` python dom = ift.UnstructuredDomain(10) fld = ift.from_random(dom) fld = ift.full(dom, 2.) a = fld.ptw("myptw") b = ift.makeField(dom, func(fld.val)) ift.extra.assert_allclose(a, b) ``` %% Cell type:markdown id:palestinian-librarian tags: `MultiField`s, ... %% Cell type:code id:naval-nightmare tags: ``` python mdom = ift.makeDomain({"bar": ift.UnstructuredDomain(10)}) mfld = ift.from_random(mdom) a = mfld.ptw("myptw") b = ift.makeField(mdom, {"bar": func(mfld["bar"].val)}) ift.extra.assert_allclose(a, b) ``` %% Cell type:markdown id:legendary-oriental tags: `Linearization`s (including the Jacobian), ... %% Cell type:code id:native-breeding tags: ``` python lin = ift.Linearization.make_var(fld) a = lin.ptw("myptw").val b = ift.makeField(dom, func(fld.val)) ift.extra.assert_allclose(a, b) ``` %% Cell type:code id:crude-motorcycle tags: ``` python op_a = lin.ptw("myptw").jac op_b = ift.makeOp(ift.makeField(dom, func_and_derv(fld.val)[1])) testing_vector = ift.from_random(dom) ift.extra.assert_allclose(op_a(testing_vector), op_b(testing_vector)) ``` %% Cell type:markdown id:outdoor-juice tags: and `Operator`s. %% Cell type:code id:retained-closer tags: ``` python op = ift.FieldAdapter(dom, "foo").ptw("myptw") ``` %% Cell type:markdown id:accessory-pepper tags: Please remember to always check that the gradient has been implemented correctly by comparint it to an approximation to the gradient by finite differences. Please remember to always check that the gradient has been implemented correctly by comparing it to an approximation to the gradient by finite differences. %% Cell type:code id:close-bonus tags: ``` python def check(func_name, eps=1e-7): pos = ift.from_random(ift.UnstructuredDomain(10)) var0 = ift.Linearization.make_var(pos) var1 = ift.Linearization.make_var(pos+eps) df0 = (var1.ptw(func_name).val - var0.ptw(func_name).val)/eps df1 = var0.ptw(func_name).jac(ift.full(lin.domain, 1.)) # rtol depends on how nonlinear the function is ift.extra.assert_allclose(df0, df1, rtol=100*eps) check("myptw") ``` %% Cell type:code id:comfortable-kitty tags: ``` python ``` ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!