Commit a81e79a3 authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge remote-tracking branch 'origin/NIFTy_7' into more_samplers

parents b187c930 c76f44f3
Pipeline #103376 passed with stages
in 16 minutes and 33 seconds
......@@ -152,3 +152,8 @@ run_meanfield:
stage: demo_runs
script:
- python3 demos/parametric_variational_inference.py
run_nonlinearity_guide:
stage: demo_runs
script:
- python3 demos/custom_nonlinearities.py
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
import nifty7 as ift
import numpy as np
# In NIFTy, users can add hand-crafted point-wise nonlinearities that are then
# available for `Field`, `MultiField`, `Linearization` and `Operator`. This
# guide illustrates how this is done.
# Suppose that we would like to use the point-wise function f(x) = x*exp(x) in
# an operator chain. This function is called "myptw" in the following. We
# introduce this function to NIFTy by implementing two functions.
# First, one that takes a `numpy.ndarray` as an input, applies the point-wise
# mapping and returns the result as a `numpy.ndarray` of the same shape.
# Second, a function that takes a `numpy.ndarray` as an input and returns two
# `numpy.ndarray`s: the application of the nonlinearity (same as before) and
# the derivative.
def func(x):
return x*np.exp(x)
def func_and_derv(x):
expx = np.exp(x)
return x*expx, (1+x)*expx
# These two functions are then added to the NIFTy-internal dictionary that
# contains all implemented point-wise nonlinearities.
ift.pointwise.ptw_dict["myptw"] = func, func_and_derv
# This allows us to apply this non-linearity on `Field`s, ...
dom = ift.UnstructuredDomain(10)
fld = ift.from_random(dom)
fld = ift.full(dom, 2.)
a = fld.ptw("myptw")
b = ift.makeField(dom, func(fld.val))
ift.extra.assert_allclose(a, b)
# `MultiField`s, ...
mdom = ift.makeDomain({"bar": ift.UnstructuredDomain(10)})
mfld = ift.from_random(mdom)
a = mfld.ptw("myptw")
b = ift.makeField(mdom, {"bar": func(mfld["bar"].val)})
ift.extra.assert_allclose(a, b)
# Linearizations (including the Jacobian), ...
# (Value)
lin = ift.Linearization.make_var(fld)
a = lin.ptw("myptw").val
b = ift.makeField(dom, func(fld.val))
ift.extra.assert_allclose(a, b)
# (Jacobian)
op_a = lin.ptw("myptw").jac
op_b = ift.makeOp(ift.makeField(dom, func_and_derv(fld.val)[1]))
testing_vector = ift.from_random(dom)
ift.extra.assert_allclose(op_a(testing_vector),
op_b(testing_vector))
# and `Operator`s.
op = ift.FieldAdapter(dom, "foo").ptw("myptw")
# We check that the gradient has been implemented correctly by comparing it to
# an approximation to the gradient by finite differences.
def check(func_name, eps=1e-7):
pos = ift.from_random(ift.UnstructuredDomain(10))
var0 = ift.Linearization.make_var(pos)
var1 = ift.Linearization.make_var(pos+eps)
df0 = (var1.ptw(func_name).val - var0.ptw(func_name).val)/eps
df1 = var0.ptw(func_name).jac(ift.full(lin.domain, 1.))
# rtol depends on how nonlinear the function is
ift.extra.assert_allclose(df0, df1, rtol=100*eps)
check("myptw")
NIFTy -- Numerical Information Field Theory
===========================================
**NIFTy** [1]_, [2]_, "\ **N**\umerical **I**\nformation **F**\ield **T**\heor\ **y**\ ", is a versatile library designed to enable the development of signal inference algorithms that are independent of the underlying grids (spatial, spectral, temporal, …) and their resolutions.
**NIFTy** [1]_ [2]_ [3]_, "\ **N**\umerical **I**\nformation **F**\ield **T**\heor\ **y**\ ", is a versatile library designed to enable the development of signal inference algorithms that are independent of the underlying grids (spatial, spectral, temporal, …) and their resolutions.
Its object-oriented framework is written in Python, although it accesses libraries written in C++ and C for efficiency.
NIFTy offers a toolkit that abstracts discretized representations of continuous spaces, fields in these spaces, and operators acting on these fields into classes.
......
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -47,9 +47,9 @@ class RGSpace(StructuredDomain):
Topologically, a n-dimensional RGSpace is a n-Torus, i.e. it has periodic
boundary conditions.
"""
_needed_for_hash = ["_distances", "_shape", "_harmonic"]
_needed_for_hash = ["_rdistances", "_shape", "_harmonic"]
def __init__(self, shape, distances=None, harmonic=False):
def __init__(self, shape, distances=None, harmonic=False, _realdistances=None):
self._harmonic = bool(harmonic)
if np.isscalar(shape):
shape = (shape,)
......@@ -57,21 +57,29 @@ class RGSpace(StructuredDomain):
if min(self._shape) < 0:
raise ValueError('Negative number of pixels encountered')
if distances is None:
if self.harmonic:
self._distances = (1.,) * len(self._shape)
else:
self._distances = tuple(1./s for s in self._shape)
elif np.isscalar(distances):
self._distances = (float(distances),) * len(self._shape)
if _realdistances is not None:
self._rdistances = _realdistances
else:
temp = np.empty(len(self.shape), dtype=np.float64)
temp[:] = distances
self._distances = tuple(temp)
if min(self._distances) <= 0:
if distances is None:
self._rdistances = tuple(1. / (np.array(self._shape)))
elif np.isscalar(distances):
if self.harmonic:
self._rdistances = tuple(
1. / (np.array(self._shape) * float(distances)))
else:
self._rdistances = (float(distances),) * len(self._shape)
else:
temp = np.empty(len(self.shape), dtype=np.float64)
temp[:] = distances
if self._harmonic:
temp = 1. / (np.array(self._shape) * temp)
self._rdistances = tuple(temp)
self._hdistances = tuple(
1. / (np.array(self.shape)*np.array(self._rdistances)))
if min(self._rdistances) <= 0:
raise ValueError('Non-positive distances encountered')
self._dvol = float(reduce(lambda x, y: x*y, self._distances))
self._dvol = float(reduce(lambda x, y: x*y, self.distances))
self._size = int(reduce(lambda x, y: x*y, self._shape))
def __repr__(self):
......@@ -181,8 +189,7 @@ class RGSpace(StructuredDomain):
RGSpace
The partner domain
"""
distances = 1. / (np.array(self.shape)*np.array(self.distances))
return RGSpace(self.shape, distances, not self.harmonic)
return RGSpace(self.shape, None, not self.harmonic, self._rdistances)
def check_codomain(self, codomain):
"""Raises `TypeError` if `codomain` is not a matching partner domain
......@@ -212,4 +219,4 @@ class RGSpace(StructuredDomain):
The n-th entry of the tuple is the distance between neighboring
grid points along the n-th dimension.
"""
return self._distances
return self._hdistances if self._harmonic else self._rdistances
......@@ -73,6 +73,10 @@ def check_linear_operator(op, domain_dtype=np.float64, target_dtype=np.float64,
_domain_check_linear(op.adjoint, target_dtype)
_domain_check_linear(op.inverse, target_dtype)
_domain_check_linear(op.adjoint.inverse, domain_dtype)
_purity_check(op, from_random(op.domain, dtype=domain_dtype))
_purity_check(op.adjoint.inverse, from_random(op.domain, dtype=domain_dtype))
_purity_check(op.adjoint, from_random(op.target, dtype=target_dtype))
_purity_check(op.inverse, from_random(op.target, dtype=target_dtype))
_check_linearity(op, domain_dtype, atol, rtol)
_check_linearity(op.adjoint, target_dtype, atol, rtol)
_check_linearity(op.inverse, target_dtype, atol, rtol)
......@@ -120,6 +124,7 @@ def check_operator(op, loc, tol=1e-12, ntries=100, perf_check=True,
if not isinstance(op, Operator):
raise TypeError('This test tests only (nonlinear) operators.')
_domain_check_nonlinear(op, loc)
_purity_check(op, loc)
_performance_check(op, loc, bool(perf_check))
_linearization_value_consistency(op, loc)
_jac_vs_finite_differences(op, loc, np.sqrt(tol), ntries,
......@@ -270,7 +275,7 @@ def _performance_check(op, pos, raise_on_fail):
myop = op @ cop
myop(pos)
cond = [cop.count != 1]
lin = myop(2*Linearization.make_var(pos, wm))
lin = myop(Linearization.make_var(pos, wm))
cond.append(cop.count != 2)
lin.jac(pos)
cond.append(cop.count != 3)
......@@ -288,6 +293,14 @@ def _performance_check(op, pos, raise_on_fail):
raise RuntimeError(s)
def _purity_check(op, pos):
if isinstance(op, LinearOperator) and (op.capability & op.TIMES) != op.TIMES:
return
res0 = op(pos)
res1 = op(pos)
assert_equal(res0, res1)
def _get_acceptable_location(op, loc, lin):
if not np.isfinite(lin.val.s_sum()):
raise ValueError('Initial value must be finite')
......
......@@ -227,9 +227,9 @@ class _GeoMetricSampler:
# Check domain dtype
dts = H._prior._met._dtype
if isinstance(H.domain, DomainTuple):
real = np.issubdtype(dts, np.float)
real = np.issubdtype(dts, np.floating)
else:
real = all([np.issubdtype(dts[kk], np.float) for kk in dts.keys()])
real = all([np.issubdtype(dts[kk], np.floating) for kk in dts.keys()])
if not real:
raise ValueError("_GeoMetricSampler only supports real valued latent DOFs.")
# /Check domain dtype
......
......@@ -602,8 +602,8 @@ def is_operator(obj):
Note
----
A simple `isinstance(obj, ift.Operator)` does give the expected
result because, e.g., :class:`~nifty7.field.Field` inherits from
A simple `isinstance(obj, ift.Operator)` does not give the expected result
because, e.g., :class:`~nifty7.field.Field` inherits from
:class:`~nifty7.operators.operator.Operator`.
"""
return isinstance(obj, Operator) and obj.val is None
......@@ -619,10 +619,10 @@ def is_fieldlike(obj):
Note
----
A simple `isinstance(obj, ift.Field)` does give the expected
result because users might have implemented another class which
behaves field-like but is not an instance of
:class:`~nifty7.field.Field`. Also not that instances of
:class:`~nifty7.linearization.Linearization` behave field-like.
A simple `isinstance(obj, ift.Field)` does not give the expected result
because users might have implemented another class which behaves field-like
but is not an instance of :class:`~nifty7.field.Field`. Also note that
instances of :class:`~nifty7.linearization.Linearization` behave
field-like.
"""
return isinstance(obj, Operator) and obj.val is not None
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2021 Max-Planck-Society
# Author: Philipp Arras
import numpy as np
import pytest
import nifty7 as ift
from time import time
from .common import list2fixture, setup_function, teardown_function
pmp = pytest.mark.parametrize
class NonPureOperator(ift.Operator):
def __init__(self, domain):
self._domain = self._target = ift.makeDomain(domain)
def apply(self, x):
self._check_input(x)
return x*time()
class NonPureLinearOperator(ift.LinearOperator):
def __init__(self, domain, cap):
self._domain = self._target = ift.makeDomain(domain)
self._capability = cap
def apply(self, x, mode):
self._check_input(x, mode)
return x*time()
@pmp("cap", [ift.LinearOperator.ADJOINT_TIMES,
ift.LinearOperator.INVERSE_TIMES | ift.LinearOperator.TIMES])
@pmp("ddtype", [np.float64, np.complex128])
@pmp("tdtype", [np.float64, np.complex128])
def test_purity_check_linear(cap, ddtype, tdtype):
dom = ift.RGSpace(2)
op = NonPureLinearOperator(dom, cap)
with pytest.raises(AssertionError):
ift.extra.check_linear_operator(op, ddtype, tdtype)
@pmp("dtype", [np.float64, np.complex128])
def test_purity_check(dtype):
dom = ift.RGSpace(2)
op = NonPureOperator(dom)
with pytest.raises(AssertionError):
ift.extra.check_operator(op, dtype)
......@@ -11,7 +11,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2019 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
......@@ -98,3 +98,9 @@ def test_k_length_array(shape, distances, expected):
def test_dvol(shape, distances, harmonic, power):
r = ift.RGSpace(shape=shape, distances=distances, harmonic=harmonic)
assert_allclose(r.dvol, np.prod(r.distances)**power)
def test_codomain():
for i in range(1, 1000):
r = ift.RGSpace(shape=(i,), distances=(1.,), harmonic=False)
assert_equal(r.get_default_codomain().get_default_codomain(), r)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment