Commit 2cbf787a authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge branch 'stochastic_samples' into 'more_samplers'

Reproduce stdnormal samples

See merge request ift/nifty!643
parents 54d832c6 44d62ab7
Pipeline #103404 passed with stages
in 15 minutes and 11 seconds
......@@ -107,6 +107,10 @@ class MeanFieldVI:
def entropy(self):
return self._entropy.force(self._KL.position)
@property
def KL(self):
return self._KL
def draw_sample(self):
_, op = self._generator.simplify_for_constant_input(
from_random(self._samdom))
......@@ -197,6 +201,10 @@ class FullCovarianceVI:
def entropy(self):
return self._entropy.force(self._KL.position)
@property
def KL(self):
return self._KL
def draw_sample(self):
_, op = self._generator.simplify_for_constant_input(
from_random(self._samdom))
......
......@@ -107,7 +107,7 @@ class StochasticEnergyAdapter(Energy):
but rather via the factory function :attr:`make`.
"""
def __init__(self, position, op, keys, local_ops, n_samples, comm, nanisinf,
_callingfrommake=False):
noise, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(StochasticEnergyAdapter, self).__init__(position)
......@@ -127,6 +127,7 @@ class StochasticEnergyAdapter(Energy):
if np.isnan(self._val) and self._nanisinf:
self._val = np.inf
self._grad = allreduce_sum(g, self._comm)/self._n_samples
self._noise = noise
self._op = op
self._keys = keys
......@@ -142,7 +143,7 @@ class StochasticEnergyAdapter(Energy):
def at(self, position):
return StochasticEnergyAdapter(position, self._op, self._keys,
self._local_ops, self._n_samples, self._comm, self._nanisinf,
_callingfrommake=True)
self._noise, _callingfrommake=True)
def apply_metric(self, x):
lin = Linearization.make_var(self.position, want_metric=True)
......@@ -201,17 +202,23 @@ class StochasticEnergyAdapter(Energy):
raise ValueError
samdom[k] = op.domain[k]
samdom = MultiDomain.make(samdom)
local_ops = []
noise = []
sseq = random.spawn_sseq(n_samples)
from .kl_energies import _get_lo_hi
for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]):
rnd = from_random(samdom)
_, tmp = op.simplify_for_constant_input(rnd)
myassert(tmp.domain == position.domain)
local_ops.append(tmp)
noise.append(rnd)
if mirror_samples:
local_ops.append(op.simplify_for_constant_input(-rnd)[1])
noise.append(-rnd)
local_ops = []
for nn in noise:
_, tmp = op.simplify_for_constant_input(nn)
myassert(tmp.domain == position.domain)
local_ops.append(tmp)
n_samples = 2*n_samples if mirror_samples else n_samples
return StochasticEnergyAdapter(position, op, sampling_keys, local_ops,
n_samples, comm, nanisinf, _callingfrommake=True)
n_samples, comm, nanisinf, noise, _callingfrommake=True)
def samples(self):
return self._noise
......@@ -115,19 +115,25 @@ def test_ParametricVI(mirror_samples, fc):
ic.enable_logging()
h = ift.StandardHamiltonian(lh, ic_samp=ic)
initial_mean = ift.from_random(h.domain, 'normal')
nsamps = 1000
nsamps = 10
args = initial_mean, h, nsamps, mirror_samples, 0.01
model = (ift.FullCovarianceVI if fc else ift.MeanFieldVI)(*args)
kl = model._KL
kl = model.KL
expected_nsamps = 2*nsamps if mirror_samples else nsamps
myassert(len(tuple(kl._local_ops)) == expected_nsamps)
true_val = []
for i in range(expected_nsamps):
lat_rnd = ift.from_random(model._KL._op.domain['latent'])
lat_rnd = ift.from_random(model.KL._op.domain['latent'])
samp = kl.position.to_dict()
samp['latent'] = lat_rnd
samp = ift.MultiField.from_dict(samp)
true_val.append(model._KL._op(samp))
true_val.append(model.KL._op(samp))
true_val = sum(true_val)/expected_nsamps
assert_allclose(true_val.val, kl.value, rtol=0.1)
samples = model.KL.samples()
ift.random.current_rng().integers(0, 100)
samples1 = model.KL.samples()
for aa, bb in zip(samples, samples1):
ift.extra.assert_allclose(aa, bb)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment