Commit fef8ef1e authored by Philipp Arras

Reproduce stdnormal samples

parent 4d2c5ffd
Pipeline #103270 passed with stages in 14 minutes and 1 second
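
This commit makes the standard-normal samples that a StochasticEnergyAdapter inserts into its energy reproducible: instead of being consumed at construction time, they can be re-drawn on demand from the stored seed sequences. The enabling mechanism is NIFTy's seeded random context, as used in the hunks below: re-entering a context built from the same SeedSequence yields the same field. A minimal sketch, not part of the commit; the RGSpace domain and the nifty8 import are illustrative assumptions:

import nifty8 as ift
from nifty8 import random

dom = ift.RGSpace(4)
sseq = random.spawn_sseq(1)      # one SeedSequence per sample
with random.Context(sseq[0]):
    a = ift.from_random(dom)     # standard-normal field
with random.Context(sseq[0]):
    b = ift.from_random(dom)     # same context, hence the same field
ift.extra.assert_allclose(a, b)  # passes: the sample is reproduced exactly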
@@ -107,6 +107,10 @@ class MeanFieldVI:
     def entropy(self):
         return self._entropy.force(self._KL.position)
 
+    @property
+    def KL(self):
+        return self._KL
+
     def draw_sample(self):
         _, op = self._generator.simplify_for_constant_input(
             from_random(self._samdom))
@@ -197,6 +201,10 @@ class FullCovarianceVI:
     def entropy(self):
         return self._entropy.force(self._KL.position)
 
+    @property
+    def KL(self):
+        return self._KL
+
     def draw_sample(self):
         _, op = self._generator.simplify_for_constant_input(
             from_random(self._samdom))
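
The two hunks above add a public read-only KL property to MeanFieldVI and FullCovarianceVI, so callers (including the adapted test at the bottom of this diff) no longer need the private _KL attribute. A usage sketch, not part of the commit; initial_mean, h and nsamps are assumed to be set up as in the test, and the trailing positional arguments simply mirror the test's call:

model = ift.MeanFieldVI(initial_mean, h, nsamps, True, 0.01)
kl = model.KL          # public accessor introduced here, replaces model._KL
print(kl.value)        # current stochastic KL estimate
noise = kl.samples()   # reproducible standard-normal samples (see below)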
@@ -107,7 +107,7 @@ class StochasticEnergyAdapter(Energy):
     but rather via the factory function :attr:`make`.
     """
     def __init__(self, position, op, keys, local_ops, n_samples, comm, nanisinf,
-                 _callingfrommake=False):
+                 sampling_context, _callingfrommake=False):
         if not _callingfrommake:
             raise NotImplementedError
         super(StochasticEnergyAdapter, self).__init__(position)
@@ -130,6 +130,7 @@ class StochasticEnergyAdapter(Energy):
         self._op = op
         self._keys = keys
+        self._context = sampling_context
 
     @property
     def value(self):
@@ -142,7 +143,7 @@ class StochasticEnergyAdapter(Energy):
     def at(self, position):
         return StochasticEnergyAdapter(position, self._op, self._keys,
                 self._local_ops, self._n_samples, self._comm, self._nanisinf,
-                _callingfrommake=True)
+                self._context, _callingfrommake=True)
 
     def apply_metric(self, x):
         lin = Linearization.make_var(self.position, want_metric=True)
@@ -161,8 +162,8 @@ class StochasticEnergyAdapter(Energy):
         return StochasticEnergyAdapter.make(position, self._op, self._keys,
                                             self._n_samples, self._comm)
 
-    @staticmethod
-    def make(position, op, sampling_keys, n_samples, mirror_samples,
+    @classmethod
+    def make(cls, position, op, sampling_keys, n_samples, mirror_samples,
              comm=None, nanisinf=False):
         """Factory function for StochasticEnergyAdapter.
@@ -201,17 +202,26 @@ class StochasticEnergyAdapter(Energy):
                 raise ValueError
             samdom[k] = op.domain[k]
         samdom = MultiDomain.make(samdom)
-        local_ops = []
         sseq = random.spawn_sseq(n_samples)
+        context = op, position, sseq, comm, n_samples, mirror_samples, samdom
+        noise = cls._draw_noise(*context)
+        local_ops = [op.simplify_for_constant_input(nn)[1] for nn in noise]
+        n_samples = 2*n_samples if mirror_samples else n_samples
+        return StochasticEnergyAdapter(position, op, sampling_keys, local_ops,
+                n_samples, comm, nanisinf, context, _callingfrommake=True)
+
+    @staticmethod
+    def _draw_noise(op, position, sseq, comm, n_samples, mirror_samples, sample_domain):
+        from .kl_energies import _get_lo_hi
+        noise = []
         for i in range(*_get_lo_hi(comm, n_samples)):
             with random.Context(sseq[i]):
-                rnd = from_random(samdom)
-                _, tmp = op.simplify_for_constant_input(rnd)
-                myassert(tmp.domain == position.domain)
-                local_ops.append(tmp)
+                rnd = from_random(sample_domain)
+                noise.append(rnd)
                 if mirror_samples:
-                    local_ops.append(op.simplify_for_constant_input(-rnd)[1])
-        n_samples = 2*n_samples if mirror_samples else n_samples
-        return StochasticEnergyAdapter(position, op, sampling_keys, local_ops,
-                n_samples, comm, nanisinf, _callingfrommake=True)
+                    noise.append(-rnd)
+        return noise
+
+    def samples(self):
+        """Standard-normal samples that have been inserted into `op`"""
+        return self._draw_noise(*self._context)
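
Since make() now stores the complete sampling context (operator, position, seed sequences, communicator, sample counts and sample domain) on the instance, samples() does not cache anything: it re-enters the same random contexts and returns exactly the noise fields that were baked into the local energies, with the mirrored partner -rnd following each rnd when mirroring is enabled. A small illustration, assuming adapter is any StochasticEnergyAdapter obtained from make():

first = adapter.samples()
second = adapter.samples()           # re-drawn from the same seed sequences
for a, b in zip(first, second):
    ift.extra.assert_allclose(a, b)  # identical standard-normal fields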
@@ -115,19 +115,26 @@ def test_ParametricVI(mirror_samples, fc):
     ic.enable_logging()
     h = ift.StandardHamiltonian(lh, ic_samp=ic)
     initial_mean = ift.from_random(h.domain, 'normal')
-    nsamps = 1000
+    nsamps = 10
     args = initial_mean, h, nsamps, mirror_samples, 0.01
     model = (ift.FullCovarianceVI if fc else ift.MeanFieldVI)(*args)
-    kl = model._KL
+    kl = model.KL
     expected_nsamps = 2*nsamps if mirror_samples else nsamps
     myassert(len(tuple(kl._local_ops)) == expected_nsamps)
 
     true_val = []
     for i in range(expected_nsamps):
-        lat_rnd = ift.from_random(model._KL._op.domain['latent'])
+        lat_rnd = ift.from_random(model.KL._op.domain['latent'])
         samp = kl.position.to_dict()
         samp['latent'] = lat_rnd
         samp = ift.MultiField.from_dict(samp)
-        true_val.append(model._KL._op(samp))
+        true_val.append(model.KL._op(samp))
     true_val = sum(true_val)/expected_nsamps
     assert_allclose(true_val.val, kl.value, rtol=0.1)
+
+    samples = model.KL.samples()
+    mini = ift.SteepestDescent(ift.GradientNormController(iteration_limit=3))
+    model.minimize(mini)
+    samples1 = model.KL.samples()
+    for aa, bb in zip(samples, samples1):
+        ift.extra.assert_allclose(aa, bb)