Commit 22c7de10 authored by Philipp Arras's avatar Philipp Arras
Browse files

WienerFilterCurvature: Interface change

This interface change is necessary since the sampling dtype cannot
easily be set by a single keyword for complicated operators.
parent 44ee2914
Pipeline #107724 passed with stages
in 21 minutes and 26 seconds
Changes since NIFTy 7
=====================
WienerFilterCurvature interface change
--------------------------------------
`ift.WienerFilterCurvature` does not expect sampling dtypes for the likelihood
and the prior anymore. These have to be set with an `ift.SamplingDtypeSetter`
beforehand.
Minisanity
----------
......
......@@ -166,6 +166,8 @@
"def Curvature(R, N, Sh):\n",
" IC = ift.GradientNormController(iteration_limit=50000,\n",
" tol_abs_gradnorm=0.1)\n",
" N = ift.SamplingDtypeSetter(N, np.float64)\n",
" Sh = ift.SamplingDtypeSetter(Sh, np.float64)\n",
" # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n",
" # helper methods.\n",
" return ift.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC)"
......
......@@ -11,21 +11,18 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2020 Max-Planck-Society
# Copyright(C) 2013-2021 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.inversion_enabler import InversionEnabler
from ..operators.sampling_enabler import SamplingDtypeSetter, SamplingEnabler
from ..operators.sampling_enabler import SamplingEnabler
from ..operators.sandwich_operator import SandwichOperator
def WienerFilterCurvature(R, N, S, iteration_controller=None,
iteration_controller_sampling=None,
data_sampling_dtype=np.float64,
prior_sampling_dtype=np.float64):
iteration_controller_sampling=None):
"""The curvature of the WienerFilterEnergy.
This operator implements the second derivative of the
......@@ -39,32 +36,27 @@ def WienerFilterCurvature(R, N, S, iteration_controller=None,
The response operator of the Wiener filter measurement.
N : EndomorphicOperator
The noise covariance.
S : DiagonalOperator
The prior signal covariance
S : EndomorphicOperator
The prior signal covariance.
iteration_controller : IterationController
The iteration controller to use during numerical inversion via
ConjugateGradient.
iteration_controller_sampling : IterationController
The iteration controller to use for sampling.
data_sampling_dtype : numpy.dtype or dict of numpy.dtype
Data type used for sampling from likelihood. Conincides with the data
type of the data used in the inference problem. Default is float64.
prior_sampling_dtype : numpy.dtype or dict of numpy.dtype
Data type used for sampling from likelihood. Coincides with the data
type of the parameters of the forward model used for the inference
problem. Default is float64.
Note
----
If samples shall be drawn from this operator, `N` and `S` have to implement
`draw_sample()`.
"""
Ninv = N.inverse
if not isinstance(N, EndomorphicOperator):
raise TypeError
if not isinstance(S, EndomorphicOperator):
raise TypeError
M = SandwichOperator.make(R, N.inverse)
Sinv = S.inverse
if data_sampling_dtype is not None:
Ninv = SamplingDtypeSetter(Ninv, data_sampling_dtype)
if prior_sampling_dtype is not None:
Sinv = SamplingDtypeSetter(Sinv, data_sampling_dtype)
M = SandwichOperator.make(R, Ninv)
if iteration_controller_sampling is not None:
op = SamplingEnabler(M, Sinv, iteration_controller_sampling,
Sinv)
op = SamplingEnabler(M, Sinv, iteration_controller_sampling, Sinv)
else:
op = M + Sinv
op = InversionEnabler(op, iteration_controller, Sinv)
return op
return InversionEnabler(op, iteration_controller, Sinv)
......@@ -79,16 +79,14 @@ def test_WF_curvature(space):
required_result = ift.full(space, 1.)
s = ift.Field.from_random(domain=space, random_type='uniform') + 0.5
S = ift.DiagonalOperator(s)
S = ift.SamplingDtypeSetter(ift.DiagonalOperator(s), np.float64)
r = ift.Field.from_random(domain=space, random_type='uniform')
R = ift.DiagonalOperator(r)
n = ift.Field.from_random(domain=space, random_type='uniform') + 0.5
N = ift.DiagonalOperator(n)
N = ift.SamplingDtypeSetter(ift.DiagonalOperator(n), np.float64)
all_diag = 1./s + r**2/n
curv = ift.WienerFilterCurvature(R, N, S, iteration_controller=IC,
iteration_controller_sampling=IC,
data_sampling_dtype=np.float64,
prior_sampling_dtype=np.float64)
iteration_controller_sampling=IC)
m = curv.inverse(required_result)
assert_allclose(
m.val,
......@@ -101,13 +99,11 @@ def test_WF_curvature(space):
if len(space.shape) == 1:
R = ift.ValueInserter(space, [0])
n = ift.from_random(R.domain, 'uniform') + 0.5
N = ift.DiagonalOperator(n)
N = ift.SamplingDtypeSetter(ift.DiagonalOperator(n), np.float64)
all_diag = 1./s + R(1/n)
curv = ift.WienerFilterCurvature(R.adjoint, N, S,
iteration_controller=IC,
iteration_controller_sampling=IC,
data_sampling_dtype=np.float64,
prior_sampling_dtype=np.float64)
iteration_controller_sampling=IC)
m = curv.inverse(required_result)
assert_allclose(
m.val,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment