Commit fef8ef1e authored by Philipp Arras's avatar Philipp Arras
Browse files

Reproduce stdnormal samples

parent 4d2c5ffd
...@@ -107,6 +107,10 @@ class MeanFieldVI: ...@@ -107,6 +107,10 @@ class MeanFieldVI:
def entropy(self): def entropy(self):
return self._entropy.force(self._KL.position) return self._entropy.force(self._KL.position)
@property
def KL(self):
return self._KL
def draw_sample(self): def draw_sample(self):
_, op = self._generator.simplify_for_constant_input( _, op = self._generator.simplify_for_constant_input(
from_random(self._samdom)) from_random(self._samdom))
...@@ -197,6 +201,10 @@ class FullCovarianceVI: ...@@ -197,6 +201,10 @@ class FullCovarianceVI:
def entropy(self): def entropy(self):
return self._entropy.force(self._KL.position) return self._entropy.force(self._KL.position)
@property
def KL(self):
return self._KL
def draw_sample(self): def draw_sample(self):
_, op = self._generator.simplify_for_constant_input( _, op = self._generator.simplify_for_constant_input(
from_random(self._samdom)) from_random(self._samdom))
......
...@@ -107,7 +107,7 @@ class StochasticEnergyAdapter(Energy): ...@@ -107,7 +107,7 @@ class StochasticEnergyAdapter(Energy):
but rather via the factory function :attr:`make`. but rather via the factory function :attr:`make`.
""" """
def __init__(self, position, op, keys, local_ops, n_samples, comm, nanisinf, def __init__(self, position, op, keys, local_ops, n_samples, comm, nanisinf,
_callingfrommake=False): sampling_context, _callingfrommake=False):
if not _callingfrommake: if not _callingfrommake:
raise NotImplementedError raise NotImplementedError
super(StochasticEnergyAdapter, self).__init__(position) super(StochasticEnergyAdapter, self).__init__(position)
...@@ -130,6 +130,7 @@ class StochasticEnergyAdapter(Energy): ...@@ -130,6 +130,7 @@ class StochasticEnergyAdapter(Energy):
self._op = op self._op = op
self._keys = keys self._keys = keys
self._context = sampling_context
@property @property
def value(self): def value(self):
...@@ -142,7 +143,7 @@ class StochasticEnergyAdapter(Energy): ...@@ -142,7 +143,7 @@ class StochasticEnergyAdapter(Energy):
def at(self, position): def at(self, position):
return StochasticEnergyAdapter(position, self._op, self._keys, return StochasticEnergyAdapter(position, self._op, self._keys,
self._local_ops, self._n_samples, self._comm, self._nanisinf, self._local_ops, self._n_samples, self._comm, self._nanisinf,
_callingfrommake=True) self._context, _callingfrommake=True)
def apply_metric(self, x): def apply_metric(self, x):
lin = Linearization.make_var(self.position, want_metric=True) lin = Linearization.make_var(self.position, want_metric=True)
...@@ -161,8 +162,8 @@ class StochasticEnergyAdapter(Energy): ...@@ -161,8 +162,8 @@ class StochasticEnergyAdapter(Energy):
return StochasticEnergyAdapter.make(position, self._op, self._keys, return StochasticEnergyAdapter.make(position, self._op, self._keys,
self._n_samples, self._comm) self._n_samples, self._comm)
@staticmethod @classmethod
def make(position, op, sampling_keys, n_samples, mirror_samples, def make(cls, position, op, sampling_keys, n_samples, mirror_samples,
comm=None, nanisinf=False): comm=None, nanisinf=False):
"""Factory function for StochasticEnergyAdapter. """Factory function for StochasticEnergyAdapter.
...@@ -201,17 +202,26 @@ class StochasticEnergyAdapter(Energy): ...@@ -201,17 +202,26 @@ class StochasticEnergyAdapter(Energy):
raise ValueError raise ValueError
samdom[k] = op.domain[k] samdom[k] = op.domain[k]
samdom = MultiDomain.make(samdom) samdom = MultiDomain.make(samdom)
local_ops = []
sseq = random.spawn_sseq(n_samples) sseq = random.spawn_sseq(n_samples)
context = op, position, sseq, comm, n_samples, mirror_samples, samdom
noise = cls._draw_noise(*context)
local_ops = [op.simplify_for_constant_input(nn)[1] for nn in noise]
n_samples = 2*n_samples if mirror_samples else n_samples
return StochasticEnergyAdapter(position, op, sampling_keys, local_ops,
n_samples, comm, nanisinf, context, _callingfrommake=True)
@staticmethod
def _draw_noise(op, position, sseq, comm, n_samples, mirror_samples, sample_domain):
from .kl_energies import _get_lo_hi from .kl_energies import _get_lo_hi
noise = []
for i in range(*_get_lo_hi(comm, n_samples)): for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]): with random.Context(sseq[i]):
rnd = from_random(samdom) rnd = from_random(sample_domain)
_, tmp = op.simplify_for_constant_input(rnd) noise.append(rnd)
myassert(tmp.domain == position.domain)
local_ops.append(tmp)
if mirror_samples: if mirror_samples:
local_ops.append(op.simplify_for_constant_input(-rnd)[1]) noise.append(-rnd)
n_samples = 2*n_samples if mirror_samples else n_samples return noise
return StochasticEnergyAdapter(position, op, sampling_keys, local_ops,
n_samples, comm, nanisinf, _callingfrommake=True) def samples(self):
"""Standard-normal samples that have been inserted into `op`"""
return self._draw_noise(*self._context)
...@@ -115,19 +115,26 @@ def test_ParametricVI(mirror_samples, fc): ...@@ -115,19 +115,26 @@ def test_ParametricVI(mirror_samples, fc):
ic.enable_logging() ic.enable_logging()
h = ift.StandardHamiltonian(lh, ic_samp=ic) h = ift.StandardHamiltonian(lh, ic_samp=ic)
initial_mean = ift.from_random(h.domain, 'normal') initial_mean = ift.from_random(h.domain, 'normal')
nsamps = 1000 nsamps = 10
args = initial_mean, h, nsamps, mirror_samples, 0.01 args = initial_mean, h, nsamps, mirror_samples, 0.01
model = (ift.FullCovarianceVI if fc else ift.MeanFieldVI)(*args) model = (ift.FullCovarianceVI if fc else ift.MeanFieldVI)(*args)
kl = model._KL kl = model.KL
expected_nsamps = 2*nsamps if mirror_samples else nsamps expected_nsamps = 2*nsamps if mirror_samples else nsamps
myassert(len(tuple(kl._local_ops)) == expected_nsamps) myassert(len(tuple(kl._local_ops)) == expected_nsamps)
true_val = [] true_val = []
for i in range(expected_nsamps): for i in range(expected_nsamps):
lat_rnd = ift.from_random(model._KL._op.domain['latent']) lat_rnd = ift.from_random(model.KL._op.domain['latent'])
samp = kl.position.to_dict() samp = kl.position.to_dict()
samp['latent'] = lat_rnd samp['latent'] = lat_rnd
samp = ift.MultiField.from_dict(samp) samp = ift.MultiField.from_dict(samp)
true_val.append(model._KL._op(samp)) true_val.append(model.KL._op(samp))
true_val = sum(true_val)/expected_nsamps true_val = sum(true_val)/expected_nsamps
assert_allclose(true_val.val, kl.value, rtol=0.1) assert_allclose(true_val.val, kl.value, rtol=0.1)
samples = model.KL.samples()
mini = ift.SteepestDescent(ift.GradientNormController(iteration_limit=3))
model.minimize(mini)
samples1 = model.KL.samples()
for aa, bb in zip(samples, samples1):
ift.extra.assert_allclose(aa, bb)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment