From c59be7601e218add56e359a01c824798d98bf2d6 Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Tue, 15 Oct 2019 17:53:44 +0200 Subject: [PATCH] Make Sampling enabler more general --- nifty5/operators/sampling_enabler.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/nifty5/operators/sampling_enabler.py b/nifty5/operators/sampling_enabler.py index 2ceec16e2..883b8f3f6 100644 --- a/nifty5/operators/sampling_enabler.py +++ b/nifty5/operators/sampling_enabler.py @@ -42,16 +42,20 @@ class SamplingEnabler(EndomorphicOperator): operator, which supports the operation modes that the operator doesn't have. It is used as a preconditioner during the iterative inversion, to accelerate convergence. + start_from_zero : boolean + If true, the conjugate gradient algorithm starts from a field filled + with zeros. Otherwise, it starts from a prior samples. Default is + False. """ def __init__(self, likelihood, prior, iteration_controller, - approximation=None): - self._op = likelihood + prior - # FIXME Separation in likelihood and prior not necessary + approximation=None, start_from_zero=False): self._likelihood = likelihood self._prior = prior self._ic = iteration_controller self._approximation = approximation + self._start_from_zero = bool(start_from_zero) + self._op = likelihood + prior self._domain = self._op.domain self._capability = self._op.capability @@ -61,8 +65,15 @@ class SamplingEnabler(EndomorphicOperator): except NotImplementedError: if not from_inverse: raise ValueError("from_inverse must be True here") - b = self._op.draw_sample() - energy = QuadraticEnergy(0*b, self._op, b) + if self._start_from_zero: + b = self._op.draw_sample() + energy = QuadraticEnergy(0*b, self._op, b) + else: + s = self._prior.draw_sample(from_inverse=True) + sp = self._prior(s) + nj = self._likelihood.draw_sample() + energy = QuadraticEnergy(s, self._op, sp + nj, + _grad=self._likelihood(s) - nj) inverter = ConjugateGradient(self._ic) if self._approximation is not None: energy, convergence = inverter( -- GitLab