diff --git a/demos/getting_started_0.ipynb b/demos/getting_started_0.ipynb index cce66e1f22b5151cb76e5fdcc6654962b45ea45c..8ef325af6d5f3d5587acebd48fb01e47a960606e 100644 --- a/demos/getting_started_0.ipynb +++ b/demos/getting_started_0.ipynb @@ -171,7 +171,7 @@ " tol_abs_gradnorm=0.1)\n", " # WienerFilterCurvature is (R.adjoint*N.inverse*R + Sh.inverse) plus some handy\n", " # helper methods.\n", - " return ift.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC, data_sampling_dtype=np.float64, prior_sampling_dtype=np.float64)" + " return ift.WienerFilterCurvature(R,N,Sh,iteration_controller=IC,iteration_controller_sampling=IC)" ] }, { diff --git a/nifty6/library/wiener_filter_curvature.py b/nifty6/library/wiener_filter_curvature.py index 7895878a72ec09967d63379c1d5320cd4a41fd9f..d69bcb5c144fd5a16d9835273e87fcbc2dd7c5e9 100644 --- a/nifty6/library/wiener_filter_curvature.py +++ b/nifty6/library/wiener_filter_curvature.py @@ -15,6 +15,8 @@ # # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. +import numpy as np + from ..operators.inversion_enabler import InversionEnabler from ..operators.sampling_enabler import SamplingDtypeSetter, SamplingEnabler from ..operators.sandwich_operator import SandwichOperator @@ -22,8 +24,8 @@ from ..operators.sandwich_operator import SandwichOperator def WienerFilterCurvature(R, N, S, iteration_controller=None, iteration_controller_sampling=None, - data_sampling_dtype=None, - prior_sampling_dtype=None): + data_sampling_dtype=np.float64, + prior_sampling_dtype=np.float64): """The curvature of the WienerFilterEnergy. This operator implements the second derivative of the @@ -44,6 +46,13 @@ def WienerFilterCurvature(R, N, S, iteration_controller=None, ConjugateGradient. iteration_controller_sampling : IterationController The iteration controller to use for sampling. + data_sampling_dtype : numpy.dtype or dict of numpy.dtype + Data type used for sampling from likelihood. Conincides with the data + type of the data used in the inference problem. Default is float64. + prior_sampling_dtype : numpy.dtype or dict of numpy.dtype + Data type used for sampling from likelihood. Coincides with the data + type of the parameters of the forward model used for the inference + problem. Default is float64. """ Ninv = N.inverse Sinv = S.inverse