diff --git a/nifty4/library/nonlinear_power_energy.py b/nifty4/library/nonlinear_power_energy.py index 8ba6c03394e1618bae710b0f6981f657fb7b52ff..4eaca6d9fbf47d50c0d68ee00c20be23ab57cbd0 100644 --- a/nifty4/library/nonlinear_power_energy.py +++ b/nifty4/library/nonlinear_power_energy.py @@ -70,7 +70,7 @@ class NonlinearPowerEnergy(Energy): if samples is None or samples == 0: xi_sample_list = [xi] else: - xi_sample_list = [D.draw_inverse_sample() + xi + xi_sample_list = [D.inverse_draw_sample() + xi for _ in range(samples)] self.xi_sample_list = xi_sample_list self.inverter = inverter diff --git a/nifty4/library/wiener_filter_curvature.py b/nifty4/library/wiener_filter_curvature.py index 3ab456a46a614dcec30915b9b97f58d076022f08..f4d7cd4e0b36829a6710aa388beb155fac9db5bd 100644 --- a/nifty4/library/wiener_filter_curvature.py +++ b/nifty4/library/wiener_filter_curvature.py @@ -16,12 +16,12 @@ # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik # and financially supported by the Studienstiftung des deutschen Volkes. -from ..operators.endomorphic_operator import EndomorphicOperator +from ..operators.sandwich_operator import SandwichOperator from ..operators.inversion_enabler import InversionEnabler import numpy as np -class WienerFilterCurvature(EndomorphicOperator): +def WienerFilterCurvature(R, N, S, inverter): """The curvature of the WienerFilterEnergy. This operator implements the second derivative of the @@ -40,28 +40,5 @@ class WienerFilterCurvature(EndomorphicOperator): inverter : Minimizer The minimizer to use during numerical inversion """ - - def __init__(self, R, N, S, inverter): - super(WienerFilterCurvature, self).__init__() - self.R = R - self.N = N - self.S = S - op = R.adjoint*N.inverse*R + S.inverse - self._op = InversionEnabler(op, inverter, S.times) - - @property - def domain(self): - return self._op.domain - - @property - def capability(self): - return self._op.capability - - def apply(self, x, mode): - return self._op.apply(x, mode) - - def draw_sample(self, dtype=np.float64): - n = self.N.inverse_draw_sample(dtype) - s = self.S.inverse_draw_sample(dtype) - - return s - self.R.adjoint_times(n) + op = SandwichOperator(R, N.inverse) + S.inverse + return InversionEnabler(op, inverter, S.times) diff --git a/nifty4/operators/diagonal_operator.py b/nifty4/operators/diagonal_operator.py index 7d4a2dd829686623e2f835f785666190d05ecc3a..03e7a9934c16744e8e2aab848f86d8735e08671d 100644 --- a/nifty4/operators/diagonal_operator.py +++ b/nifty4/operators/diagonal_operator.py @@ -147,7 +147,8 @@ class DiagonalOperator(EndomorphicOperator): return res def draw_sample(self, dtype=np.float64): - if np.issubdtype(self._ldiag.dtype, np.complexfloating) or (self._ldiag <= 0.).any(): + if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or + (self._ldiag <= 0.).any()): raise ValueError("operator not positive definite") res = Field.from_random(random_type="normal", domain=self._domain, dtype=dtype) @@ -155,7 +156,8 @@ class DiagonalOperator(EndomorphicOperator): return res def inverse_draw_sample(self, dtype=np.float64): - if np.issubdtype(self._ldiag.dtype, np.complexfloating) or (self._ldiag <= 0.).any(): + if (np.issubdtype(self._ldiag.dtype, np.complexfloating) or + (self._ldiag <= 0.).any()): raise ValueError("operator not positive definite") res = Field.from_random(random_type="normal", domain=self._domain, diff --git a/nifty4/operators/endomorphic_operator.py b/nifty4/operators/endomorphic_operator.py index 0c40ebb98873b07cfdcea061af6fe205c2183878..e456923c0d9ece42b7eb3e2d477a6272c630fa53 100644 --- a/nifty4/operators/endomorphic_operator.py +++ b/nifty4/operators/endomorphic_operator.py @@ -60,7 +60,7 @@ class EndomorphicOperator(LinearOperator): A sample from the Gaussian of given covariance """ if self.capability & self.INVERSE_TIMES: - x = self.draw_sample(dtype=dtype) + x = self.draw_sample(dtype) return self.inverse_times(x) else: raise NotImplementedError diff --git a/nifty4/operators/inverse_operator.py b/nifty4/operators/inverse_operator.py index cca961317b73181b249c901377a3ba151fa4d968..655c3f91ec3bab3879dafba15fd62847369e127c 100644 --- a/nifty4/operators/inverse_operator.py +++ b/nifty4/operators/inverse_operator.py @@ -17,6 +17,7 @@ # and financially supported by the Studienstiftung des deutschen Volkes. from .linear_operator import LinearOperator +import numpy as np class InverseOperator(LinearOperator): @@ -44,3 +45,9 @@ class InverseOperator(LinearOperator): def apply(self, x, mode): return self._op.apply(x, self._inverseMode[mode]) + + def draw_sample(self, dtype=np.float64): + return self._op.inverse_draw_sample(dtype) + + def inverse_draw_sample(self, dtype=np.float64): + return self._op.draw_sample(dtype) diff --git a/nifty4/operators/inversion_enabler.py b/nifty4/operators/inversion_enabler.py index ea6a436c4d24de356aff614c683ddf8f42b06b6e..72aaa24641b3fb9eba784668e8c5d2829522845e 100644 --- a/nifty4/operators/inversion_enabler.py +++ b/nifty4/operators/inversion_enabler.py @@ -21,6 +21,7 @@ from ..minimization.iteration_controller import IterationController from ..field import Field from ..logger import logger from .linear_operator import LinearOperator +import numpy as np class InversionEnabler(LinearOperator): @@ -74,3 +75,15 @@ class InversionEnabler(LinearOperator): if stat != IterationController.CONVERGED: logger.warning("Error detected during operator inversion") return r.position + + def draw_sample(self, dtype=np.float64): + try: + return self._op.draw_sample(dtype) + except: + return self(self._op.inverse_draw_sample(dtype)) + + def inverse_draw_sample(self, dtype=np.float64): + try: + return self._op.inverse_draw_sample(dtype) + except: + return self.inverse_times(self._op.draw_sample(dtype)) diff --git a/nifty4/operators/sandwich_operator.py b/nifty4/operators/sandwich_operator.py index 65d66bb6ac425f4c3a650a50cf1296eebd01940c..e324b135db40822ee10204f0db31442550fc7d32 100644 --- a/nifty4/operators/sandwich_operator.py +++ b/nifty4/operators/sandwich_operator.py @@ -17,6 +17,7 @@ # and financially supported by the Studienstiftung des deutschen Volkes. from .endomorphic_operator import EndomorphicOperator +import numpy as np class SandwichOperator(EndomorphicOperator): @@ -47,5 +48,5 @@ class SandwichOperator(EndomorphicOperator): def apply(self, x, mode): return self._op.apply(x, mode) - def draw_sample(self): - return self._bun.adjoint_times(self._cheese.draw_sample()) + def draw_sample(self, dtype=np.float64): + return self._bun.adjoint_times(self._cheese.draw_sample(dtype)) diff --git a/nifty4/operators/sum_operator.py b/nifty4/operators/sum_operator.py index 6e8a253c19b27cdc02a54aeeae190d0e82b3f52c..c4e281099c3d6af7dba46664f777ffb154108f2c 100644 --- a/nifty4/operators/sum_operator.py +++ b/nifty4/operators/sum_operator.py @@ -17,6 +17,7 @@ # and financially supported by the Studienstiftung des deutschen Volkes. from .linear_operator import LinearOperator +import numpy as np class SumOperator(LinearOperator): @@ -143,3 +144,9 @@ class SumOperator(LinearOperator): else: res += op.apply(x, mode) return res + + def draw_sample(self, dtype=np.float64): + res = self._ops[0].draw_sample(dtype) + for op in self._ops[1:]: + res += op.draw_sample(dtype) + return res diff --git a/nifty4/probing/utils.py b/nifty4/probing/utils.py index e11b2bd7cc2ed9b6fcf3cf18a2d7bb892f08220e..27aa1b9dba7b5df75ef91ae1dfde03d617d9c526 100644 --- a/nifty4/probing/utils.py +++ b/nifty4/probing/utils.py @@ -51,7 +51,7 @@ class StatCalculator(object): def probe_with_posterior_samples(op, post_op, nprobes): sc = StatCalculator() for i in range(nprobes): - sample = post_op(op.draw_inverse_sample()) + sample = post_op(op.inverse_draw_sample()) sc.add(sample) if nprobes == 1: diff --git a/test/test_energies/test_noise.py b/test/test_energies/test_noise.py index f6ddf18747953ca34d62a455396a4b185f5579db..cd53bde89c95e56e697882efaf81525c3fb277de 100644 --- a/test/test_energies/test_noise.py +++ b/test/test_energies/test_noise.py @@ -84,7 +84,7 @@ class Noise_Energy_Tests(unittest.TestCase): S=S, inverter=inverter).curvature - res_sample_list = [d - R(f(ht(C.draw_inverse_sample() + xi))) + res_sample_list = [d - R(f(ht(C.inverse_draw_sample() + xi))) for _ in range(10)] energy0 = ift.library.NoiseEnergy(eta0, alpha, q, res_sample_list)