Commit 44d62ab7 authored by Philipp Arras's avatar Philipp Arras
Browse files

Just store the samples

parent 0427c08e
Pipeline #103383 passed with stages
in 13 minutes and 55 seconds
......@@ -107,7 +107,7 @@ class StochasticEnergyAdapter(Energy):
but rather via the factory function :attr:`make`.
def __init__(self, position, op, keys, local_ops, n_samples, comm, nanisinf,
sampling_context, _callingfrommake=False):
noise, _callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
super(StochasticEnergyAdapter, self).__init__(position)
......@@ -127,10 +127,10 @@ class StochasticEnergyAdapter(Energy):
if np.isnan(self._val) and self._nanisinf:
self._val = np.inf
self._grad = allreduce_sum(g, self._comm)/self._n_samples
self._noise = noise
self._op = op
self._keys = keys
self._context = sampling_context
def value(self):
......@@ -143,7 +143,7 @@ class StochasticEnergyAdapter(Energy):
def at(self, position):
return StochasticEnergyAdapter(position, self._op, self._keys,
self._local_ops, self._n_samples, self._comm, self._nanisinf,
self._context, _callingfrommake=True)
self._noise, _callingfrommake=True)
def apply_metric(self, x):
lin = Linearization.make_var(self.position, want_metric=True)
......@@ -162,8 +162,8 @@ class StochasticEnergyAdapter(Energy):
return StochasticEnergyAdapter.make(position, self._op, self._keys,
self._n_samples, self._comm)
def make(cls, position, op, sampling_keys, n_samples, mirror_samples,
def make(position, op, sampling_keys, n_samples, mirror_samples,
comm=None, nanisinf=False):
"""Factory function for StochasticEnergyAdapter.
......@@ -202,29 +202,23 @@ class StochasticEnergyAdapter(Energy):
raise ValueError
samdom[k] = op.domain[k]
samdom = MultiDomain.make(samdom)
seed = int(random.current_rng().integers(0, 1000000))
context = op, position, seed, comm, n_samples, mirror_samples, samdom
noise = cls._draw_noise(*context)
local_ops = [op.simplify_for_constant_input(nn)[1] for nn in noise]
n_samples = 2*n_samples if mirror_samples else n_samples
return StochasticEnergyAdapter(position, op, sampling_keys, local_ops,
n_samples, comm, nanisinf, context, _callingfrommake=True)
def _draw_noise(op, position, seed, comm, n_samples, mirror_samples, sample_domain):
from .kl_energies import _get_lo_hi
with random.Context(seed):
sseq = random.spawn_sseq(n_samples)
noise = []
sseq = random.spawn_sseq(n_samples)
from .kl_energies import _get_lo_hi
for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]):
rnd = from_random(sample_domain)
rnd = from_random(samdom)
if mirror_samples:
return noise
local_ops = []
for nn in noise:
_, tmp = op.simplify_for_constant_input(nn)
myassert(tmp.domain == position.domain)
n_samples = 2*n_samples if mirror_samples else n_samples
return StochasticEnergyAdapter(position, op, sampling_keys, local_ops,
n_samples, comm, nanisinf, noise, _callingfrommake=True)
def samples(self):
"""Standard-normal samples that have been inserted into `op`"""
return self._draw_noise(*self._context)
return self._noise
......@@ -133,7 +133,7 @@ def test_ParametricVI(mirror_samples, fc):
assert_allclose(true_val.val, kl.value, rtol=0.1)
samples = model.KL.samples()
ift.random.current_rng().integers(0, 100)
samples1 = model.KL.samples()
for aa, bb in zip(samples, samples1):
ift.extra.assert_allclose(aa, bb)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment