Skip to content
Snippets Groups Projects
Commit ae950a6e authored by Philipp Frank's avatar Philipp Frank
Browse files

add nifty jacobian demo

parent ab243150
No related branches found
No related tags found
1 merge request!8Add a nifty jacobian demo
Pipeline #150163 passed
import nifty8 as ift
from nifty8.sugar import is_linearization
# We want to build a custom model that evaluates the function:
# f(x,y) = exp(x) * y + 3
# and furthermore allows this model to be used in a fit (i.E. the model
# should also know about its jacobian). To do so, the model should be able to
# handle inputs being either Fields or Linearizations, with the latter
# consisting of the input Field and its Jacobian, evaluated at the input Field
# value. The Jacobian is represented as a LinearOperator which stores how the
# Jacobian (and its adjoint) can be applied to vectors (Fields) rather then
# storing the Jacobian as an explicit matrix (which is prohibitively large for
# bigger applications).
# There are as of now three different ways to achieve this on different levels:
# 1) The Field level, 2) The Linearization level, 3) The Operator level.
# The three ways will be constructed in the following.
# Defining domains for the problem:
# The inputs x,y should live on a scalar domain
scalar_domain = ift.DomainTuple.scalar_domain()
# The full input domain (of (x,y) together) should be a `MultiDomain` containing
# both inputs and assigns the keys 'x' and 'y' to the inputs to identify them.
full_domain = ift.MultiDomain.make({'x':scalar_domain, 'y':scalar_domain})
# 1) The Field level (most general, least fail safe)
# In this most basic form the user defines by hand what should happen in case
# the input is a Field or a Linearization:
# Define Operator
class MyCustomFieldModel(ift.Operator):
def __init__(self, domain, keys = ('x', 'y')):
self._domain = ift.makeDomain(domain)
if not isinstance(self._domain, ift.MultiDomain):
raise ValueError
self._keys = keys
for k in self._keys:
if k not in self._domain.keys():
raise ValueError
if len(self._keys) != 2:
raise ValueError
self._target = ift.makeDomain(domain[self._keys[0]])
def apply(self, x):
self._check_input(x)
lin = is_linearization(x) # checks whether x is a linearization or not
v = x.val if lin else x # if x is lin the Field values are stored in x.val
res = ift.exp(v[self._keys[0]]) * v[self._keys[1]] + 3.
if not lin: # in case the input is a Field simply return the result
return res
# in case x is a Linearization, we also have to provide the Jacobian
# and pass it along. (see below)
jac = _MyCustomJacobian(v, self._keys)
# in this case apply also returns a Linearization.
return x.new(res, jac) # creates a new Lin of the correct form (from x).
# Definition of the Jacobian
class _MyCustomJacobian(ift.LinearOperator):
def __init__(self, location, keys):
self._domain = ift.makeDomain(location.domain)
self._keys = keys
self._target = ift.makeDomain(self._domain[self._keys[0]])
self._capability = self.TIMES | self.ADJOINT_TIMES
# We can precompute the two derivatives required to build the jacobian
self._jy = ift.exp(location[self._keys[0]]) # df/dy = exp(x)
self._jx = self._jy * location[self._keys[1]] # df/dx = exp(x) * y
def apply(self, inp, mode):
self._check_input(inp, mode)
if mode == self.TIMES:
# This part defines the application of the jacobian matrix to an
# arbitraty input vector inp
return self._jx * inp[self._keys[0]] + self._jy * inp[self._keys[1]]
else:
# Here the adjoint application is defined, the input in this case is
# only a scalar and gets weighted with the entries of the jacobian
res = {self._keys[0]: self._jx * inp, self._keys[1]: self._jy * inp}
return ift.MultiField.from_dict(res, self._domain)
# 2) The Linearization level (very general, more fail safe)
# The behaviour of many basic operations (addition, mutiplication, pointwise
# non-linear functions), when facing Fields or Lineariations have already been
# implemented in nifty. Therefore in many cases one can "do calculations" with
# Linearizations analogous to Fields and the Jacobians get constructed from the
# Jacobians of the promitive operations automatically, using the chain rule:
# Define Operator
class MyCustomLinModel(ift.Operator):
def __init__(self, domain, keys = ('x', 'y')):
self._domain = ift.makeDomain(domain)
if not isinstance(self._domain, ift.MultiDomain):
raise ValueError
self._keys = keys
for k in self._keys:
if k not in self._domain.keys():
raise ValueError
if len(self._keys) != 2:
raise ValueError
self._target = ift.makeDomain(domain[self._keys[0]])
def apply(self, x):
self._check_input(x)
# All operations needed here (input selection, exp, multiplication,
# scalar addition) are defined for Fields as well as Linearizations
# and therefore the same code can be used in both cases.
return ift.exp(x[self._keys[0]]) * x[self._keys[1]] + 3.
# 2) The Operator level (a bit less general, most fail safe)
# Similarly to Fields and Linearozations, also the basic operations are
# implemented for Operators. Given for example two operators op1, op2 we can
# multiply them together to create a new op3 = op1 * op2. The logic here is:
# "take the outputs (results) of op1 and op2 and multiply them together". The
# new operator op3 gets the union of the input domains of op1 and op2 as its
# domain. The target of op3 is the target of op1 and op2, so only if op1 and op2
# share the same target space, pointwise multiplication is possible (no
# automatic broadcasting!). Nifty, however, provides many broadcasting
# operations the user can invoke to help build more general combinations (see
# e.g. `ContractionOperator`)
# a placeholder operator that takes 'x' from the input and passes it along.
x = ift.FieldAdapter(scalar_domain, 'x')
y = ift.FieldAdapter(scalar_domain, 'y')
# First part of the model. Note that here 'calculations' are performed on
# operators to create a new operator (i.E. no inputs are provided yet).
model = ift.exp(x) * y
# As "+" is already reserved for adding the output of two operators together
# simple scalar addition has to be performed via a designated operator 'Adder'.
# Also, on operator level, the combination of operations has to be performed on
# the same space so the number '3' is cast into a Field on the 'scalar_domain`
# to add it to the output of model. Finally sequential operator application is
# given via the matrix multiplication `@`.
MyCustomModel = ift.Adder(ift.full(scalar_domain, 3.)) @ model
# We can verify that the three operators yield the same results on some input.
inp = ift.from_random(MyCustomModel.domain)
MyFieldModel = MyCustomFieldModel(full_domain)
MyLinModel = MyCustomLinModel(full_domain)
res1 = MyFieldModel(inp)
res2 = MyLinModel(inp)
res3 = MyCustomModel(inp)
ift.extra.assert_allclose(res1, res2)
ift.extra.assert_allclose(res1, res3)
# We can also build a Linearization and apply the models
lin = ift.Linearization.make_var(inp)
res1 = MyFieldModel(lin)
res2 = MyLinModel(lin)
res3 = MyCustomModel(lin)
# The resulting Field values (accessible via .val) still match.
ift.extra.assert_allclose(res1.val, res2.val)
ift.extra.assert_allclose(res1.val, res3.val)
# Looking at the Jacobians evaluated at inp (accessible via .jac), we notice the
# differences in implementation of the three approaches.
print(res1.jac)
print("------")
print(res2.jac)
print("------")
print(res3.jac)
# Applying these jacobians (and their adjoint) to random inputs, however,
# gives the same results as they all apply the same matrix mathematically.
jac_inp = ift.from_random(res1.jac.domain)
ift.extra.assert_allclose(res1.jac(jac_inp), res2.jac(jac_inp))
ift.extra.assert_allclose(res1.jac(jac_inp), res3.jac(jac_inp))
# same for the adjoint application
jac_inp = ift.from_random(res1.jac.target)
ift.extra.assert_allclose(res1.jac.adjoint(jac_inp), res2.jac.adjoint(jac_inp))
ift.extra.assert_allclose(res1.jac.adjoint(jac_inp), res3.jac.adjoint(jac_inp))
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment