Skip to content
Snippets Groups Projects
Commit c59be760 authored by Philipp Arras's avatar Philipp Arras
Browse files

Make Sampling enabler more general

parent 5d5c56ac
No related branches found
No related tags found
1 merge request!333Operator spectra
Pipeline #61912 passed
...@@ -42,16 +42,20 @@ class SamplingEnabler(EndomorphicOperator): ...@@ -42,16 +42,20 @@ class SamplingEnabler(EndomorphicOperator):
operator, which supports the operation modes that the operator doesn't operator, which supports the operation modes that the operator doesn't
have. It is used as a preconditioner during the iterative inversion, have. It is used as a preconditioner during the iterative inversion,
to accelerate convergence. to accelerate convergence.
start_from_zero : boolean
If true, the conjugate gradient algorithm starts from a field filled
with zeros. Otherwise, it starts from a prior samples. Default is
False.
""" """
def __init__(self, likelihood, prior, iteration_controller, def __init__(self, likelihood, prior, iteration_controller,
approximation=None): approximation=None, start_from_zero=False):
self._op = likelihood + prior
# FIXME Separation in likelihood and prior not necessary
self._likelihood = likelihood self._likelihood = likelihood
self._prior = prior self._prior = prior
self._ic = iteration_controller self._ic = iteration_controller
self._approximation = approximation self._approximation = approximation
self._start_from_zero = bool(start_from_zero)
self._op = likelihood + prior
self._domain = self._op.domain self._domain = self._op.domain
self._capability = self._op.capability self._capability = self._op.capability
...@@ -61,8 +65,15 @@ class SamplingEnabler(EndomorphicOperator): ...@@ -61,8 +65,15 @@ class SamplingEnabler(EndomorphicOperator):
except NotImplementedError: except NotImplementedError:
if not from_inverse: if not from_inverse:
raise ValueError("from_inverse must be True here") raise ValueError("from_inverse must be True here")
b = self._op.draw_sample() if self._start_from_zero:
energy = QuadraticEnergy(0*b, self._op, b) b = self._op.draw_sample()
energy = QuadraticEnergy(0*b, self._op, b)
else:
s = self._prior.draw_sample(from_inverse=True)
sp = self._prior(s)
nj = self._likelihood.draw_sample()
energy = QuadraticEnergy(s, self._op, sp + nj,
_grad=self._likelihood(s) - nj)
inverter = ConjugateGradient(self._ic) inverter = ConjugateGradient(self._ic)
if self._approximation is not None: if self._approximation is not None:
energy, convergence = inverter( energy, convergence = inverter(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment