Commit 68407686 authored by Philipp Arras's avatar Philipp Arras
Browse files

Remove default for mirror_samples

Closes #314

* Use mirrored_samples in demos

* Use minimized samples for posterior analysis. This is our general
practise and also applied in all other demos.
parent 9cbf7497
......@@ -50,7 +50,8 @@ MetricGaussianKL interface
--------------------------
Users do not instantiate `MetricGaussianKL` by its constructor anymore. Rather
`MetricGaussianKL.make()` shall be used.
`MetricGaussianKL.make()` shall be used. Additionally, `mirror_samples` is not
set by default anymore.
Changes since NIFTy 5
......
......@@ -123,7 +123,7 @@ def main():
# Draw new samples to approximate the KL five times
for i in range(5):
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
KL = ift.MetricGaussianKL.make(mean, H, N_samples, True)
KL, convergence = minimizer(KL)
mean = KL.position
ift.extra.minisanity(KL, data, sig.inverse, signal_response)
......
......@@ -129,9 +129,9 @@ def main():
initial_mean = ift.MultiField.full(H.domain, 0.)
mean = initial_mean
for i in range(10):
for i in range(5):
# Draw new samples and minimize KL
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
KL = ift.MetricGaussianKL.make(mean, H, N_samples, True)
KL, convergence = minimizer(KL)
mean = KL.position
......@@ -157,7 +157,6 @@ def main():
name=filename.format("loop_{:02d}".format(i)))
# Done, draw posterior samples
KL = ift.MetricGaussianKL.make(mean, H, N_samples)
sc = ift.StatCalculator()
scA1 = ift.StatCalculator()
scA2 = ift.StatCalculator()
......
......@@ -91,7 +91,7 @@ def main():
plt.figure(figsize=[12, 8])
for ii in range(15):
if ii % 3 == 0:
mgkl = ift.MetricGaussianKL.make(pos, ham, 40)
mgkl = ift.MetricGaussianKL.make(pos, ham, 40, False)
plt.cla()
plt.imshow(z.T, origin='lower', norm=LogNorm(), vmin=1e-3,
......
......@@ -126,8 +126,8 @@ class MetricGaussianKL(Energy):
self._grad = utilities.allreduce_sum(g, self._comm)/self.n_eff_samples
@staticmethod
def make(mean, hamiltonian, n_samples, constants=[], point_estimates=[],
mirror_samples=False, napprox=0, comm=None, nanisinf=False):
def make(mean, hamiltonian, n_samples, mirror_samples, constants=[],
point_estimates=[], napprox=0, comm=None, nanisinf=False):
"""Return instance of :class:`MetricGaussianKL`.
Parameters
......@@ -138,6 +138,12 @@ class MetricGaussianKL(Energy):
Hamiltonian of the approximated probability distribution.
n_samples : integer
Number of samples used to stochastically estimate the KL.
mirror_samples : boolean
Whether the negative of the drawn samples are also used, as they are
equally legitimate samples. If true, the number of used samples
doubles. Mirroring samples stabilizes the KL estimate as extreme
sample variation is counterbalanced. Since it improves stability in
many cases, it is recommended to set `mirror_samples` to `True`.
constants : list
List of parameter keys that are kept constant during optimization.
Default is no constants.
......@@ -145,11 +151,6 @@ class MetricGaussianKL(Energy):
List of parameter keys for which no samples are drawn, but that are
(possibly) optimized for, corresponding to point estimates of these.
Default is to draw samples for the complete domain.
mirror_samples : boolean
Whether the negative of the drawn samples are also used,
as they are equally legitimate samples. If true, the number of used
samples doubles. Mirroring samples stabilizes the KL estimate as
extreme sample variation is counterbalanced. Default is False.
napprox : int
Number of samples for computing preconditioner for sampling. No
preconditioning is done by default.
......
......@@ -524,7 +524,7 @@ def calculate_position(operator, output):
minimizer = NewtonCG(GradientNormController(iteration_limit=10, name='findpos'))
for ii in range(3):
logger.info(f'Start iteration {ii+1}/3')
kl = MetricGaussianKL.make(pos, H, 3, mirror_samples=True)
kl = MetricGaussianKL.make(pos, H, 3, True)
kl, _ = minimizer(kl)
pos = kl.position
return pos
......
......@@ -49,7 +49,7 @@ def test_exec_time():
lh = ift.GaussianEnergy(domain=op.target, sampling_dtype=np.float64) @ op1
ic = ift.GradientNormController(iteration_limit=2)
ham = ift.StandardHamiltonian(lh, ic_samp=ic)
kl = ift.MetricGaussianKL.make(ift.full(ham.domain, 0.), ham, 1)
kl = ift.MetricGaussianKL.make(ift.full(ham.domain, 0.), ham, 1, False)
ops = [op, op1, lh, ham, kl]
for oo in ops:
for wm in [True, False]:
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment