Commit 22c7de10 authored by Philipp Arras's avatar Philipp Arras
Browse files

WienerFilterCurvature: Interface change

This interface change is necessary since the sampling dtype cannot
easily be set by a single keyword for complicated operators.
parent 44ee2914
Pipeline #107724 passed with stages
in 21 minutes and 26 seconds
Changes since NIFTy 7 Changes since NIFTy 7
===================== =====================
WienerFilterCurvature interface change
--------------------------------------
`ift.WienerFilterCurvature` does not expect sampling dtypes for the likelihood
and the prior anymore. These have to be set with an `ift.SamplingDtypeSetter`
beforehand.
Minisanity Minisanity
---------- ----------
......
...@@ -92,10 +92,12 @@ ...@@ -92,10 +92,12 @@
``` python ``` python
def Curvature(R, N, Sh): def Curvature(R, N, Sh):
IC = ift.GradientNormController(iteration_limit=50000, IC = ift.GradientNormController(iteration_limit=50000,
tol_abs_gradnorm=0.1) tol_abs_gradnorm=0.1)
N = ift.SamplingDtypeSetter(N, np.float64)
Sh = ift.SamplingDtypeSetter(Sh, np.float64)
# WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy
# helper methods. # helper methods.
return ift.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC) return ift.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC)
``` ```
......
...@@ -11,21 +11,18 @@ ...@@ -11,21 +11,18 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
# Copyright(C) 2013-2020 Max-Planck-Society # Copyright(C) 2013-2021 Max-Planck-Society
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
import numpy as np from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.inversion_enabler import InversionEnabler from ..operators.inversion_enabler import InversionEnabler
from ..operators.sampling_enabler import SamplingDtypeSetter, SamplingEnabler from ..operators.sampling_enabler import SamplingEnabler
from ..operators.sandwich_operator import SandwichOperator from ..operators.sandwich_operator import SandwichOperator
def WienerFilterCurvature(R, N, S, iteration_controller=None, def WienerFilterCurvature(R, N, S, iteration_controller=None,
iteration_controller_sampling=None, iteration_controller_sampling=None):
data_sampling_dtype=np.float64,
prior_sampling_dtype=np.float64):
"""The curvature of the WienerFilterEnergy. """The curvature of the WienerFilterEnergy.
This operator implements the second derivative of the This operator implements the second derivative of the
...@@ -39,32 +36,27 @@ def WienerFilterCurvature(R, N, S, iteration_controller=None, ...@@ -39,32 +36,27 @@ def WienerFilterCurvature(R, N, S, iteration_controller=None,
The response operator of the Wiener filter measurement. The response operator of the Wiener filter measurement.
N : EndomorphicOperator N : EndomorphicOperator
The noise covariance. The noise covariance.
S : DiagonalOperator S : EndomorphicOperator
The prior signal covariance The prior signal covariance.
iteration_controller : IterationController iteration_controller : IterationController
The iteration controller to use during numerical inversion via The iteration controller to use during numerical inversion via
ConjugateGradient. ConjugateGradient.
iteration_controller_sampling : IterationController iteration_controller_sampling : IterationController
The iteration controller to use for sampling. The iteration controller to use for sampling.
data_sampling_dtype : numpy.dtype or dict of numpy.dtype
Data type used for sampling from likelihood. Conincides with the data Note
type of the data used in the inference problem. Default is float64. ----
prior_sampling_dtype : numpy.dtype or dict of numpy.dtype If samples shall be drawn from this operator, `N` and `S` have to implement
Data type used for sampling from likelihood. Coincides with the data `draw_sample()`.
type of the parameters of the forward model used for the inference
problem. Default is float64.
""" """
Ninv = N.inverse if not isinstance(N, EndomorphicOperator):
raise TypeError
if not isinstance(S, EndomorphicOperator):
raise TypeError
M = SandwichOperator.make(R, N.inverse)
Sinv = S.inverse Sinv = S.inverse
if data_sampling_dtype is not None:
Ninv = SamplingDtypeSetter(Ninv, data_sampling_dtype)
if prior_sampling_dtype is not None:
Sinv = SamplingDtypeSetter(Sinv, data_sampling_dtype)
M = SandwichOperator.make(R, Ninv)
if iteration_controller_sampling is not None: if iteration_controller_sampling is not None:
op = SamplingEnabler(M, Sinv, iteration_controller_sampling, op = SamplingEnabler(M, Sinv, iteration_controller_sampling, Sinv)
Sinv)
else: else:
op = M + Sinv op = M + Sinv
op = InversionEnabler(op, iteration_controller, Sinv) return InversionEnabler(op, iteration_controller, Sinv)
return op
...@@ -79,16 +79,14 @@ def test_WF_curvature(space): ...@@ -79,16 +79,14 @@ def test_WF_curvature(space):
required_result = ift.full(space, 1.) required_result = ift.full(space, 1.)
s = ift.Field.from_random(domain=space, random_type='uniform') + 0.5 s = ift.Field.from_random(domain=space, random_type='uniform') + 0.5
S = ift.DiagonalOperator(s) S = ift.SamplingDtypeSetter(ift.DiagonalOperator(s), np.float64)
r = ift.Field.from_random(domain=space, random_type='uniform') r = ift.Field.from_random(domain=space, random_type='uniform')
R = ift.DiagonalOperator(r) R = ift.DiagonalOperator(r)
n = ift.Field.from_random(domain=space, random_type='uniform') + 0.5 n = ift.Field.from_random(domain=space, random_type='uniform') + 0.5
N = ift.DiagonalOperator(n) N = ift.SamplingDtypeSetter(ift.DiagonalOperator(n), np.float64)
all_diag = 1./s + r**2/n all_diag = 1./s + r**2/n
curv = ift.WienerFilterCurvature(R, N, S, iteration_controller=IC, curv = ift.WienerFilterCurvature(R, N, S, iteration_controller=IC,
iteration_controller_sampling=IC, iteration_controller_sampling=IC)
data_sampling_dtype=np.float64,
prior_sampling_dtype=np.float64)
m = curv.inverse(required_result) m = curv.inverse(required_result)
assert_allclose( assert_allclose(
m.val, m.val,
...@@ -101,13 +99,11 @@ def test_WF_curvature(space): ...@@ -101,13 +99,11 @@ def test_WF_curvature(space):
if len(space.shape) == 1: if len(space.shape) == 1:
R = ift.ValueInserter(space, [0]) R = ift.ValueInserter(space, [0])
n = ift.from_random(R.domain, 'uniform') + 0.5 n = ift.from_random(R.domain, 'uniform') + 0.5
N = ift.DiagonalOperator(n) N = ift.SamplingDtypeSetter(ift.DiagonalOperator(n), np.float64)
all_diag = 1./s + R(1/n) all_diag = 1./s + R(1/n)
curv = ift.WienerFilterCurvature(R.adjoint, N, S, curv = ift.WienerFilterCurvature(R.adjoint, N, S,
iteration_controller=IC, iteration_controller=IC,
iteration_controller_sampling=IC, iteration_controller_sampling=IC)
data_sampling_dtype=np.float64,
prior_sampling_dtype=np.float64)
m = curv.inverse(required_result) m = curv.inverse(required_result)
assert_allclose( assert_allclose(
m.val, m.val,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment