Commit 193a276f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'KL_sample_generator' into 'NIFTy_6'

Use a generator for MetricGaussianKL.samples

See merge request !432
parents c01d10cc 8e377ee6
Pipeline #71475 passed with stages
in 27 minutes and 46 seconds
......@@ -121,7 +121,7 @@ class MetricGaussianKL(Energy):
the presence of this parameter is that metric of the likelihood energy
is just an `Operator` which does not know anything about the dtype of
the fields on which it acts. Default is float64.
_samples : None
_local_samples : None
Only a parameter for internal uses. Typically not to be set by users.
Note
......@@ -138,7 +138,7 @@ class MetricGaussianKL(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False,
napprox=0, comm=None, _samples=None,
napprox=0, comm=None, _local_samples=None,
lh_sampling_dtype=np.float64):
super(MetricGaussianKL, self).__init__(mean)
......@@ -170,31 +170,31 @@ class MetricGaussianKL(Energy):
if self._mirror_samples:
self._n_eff_samples *= 2
if _samples is None:
if _local_samples is None:
met = hamiltonian(Linearization.make_partial_var(
mean, self._point_estimates, True)).metric
if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox))
_samples = []
_local_samples = []
sseq = random.spawn_sseq(self._n_samples)
for i in range(self._lo, self._hi):
random.push_sseq(sseq[i])
_samples.append(met.draw_sample(from_inverse=True,
dtype=lh_sampling_dtype))
_local_samples.append(met.draw_sample(from_inverse=True,
dtype=lh_sampling_dtype))
random.pop_sseq()
_samples = tuple(_samples)
_local_samples = tuple(_local_samples)
else:
if len(_samples) != self._hi-self._lo:
if len(_local_samples) != self._hi-self._lo:
raise ValueError("# of samples mismatch")
self._samples = _samples
self._local_samples = _local_samples
self._lin = Linearization.make_partial_var(mean, self._constants)
v, g = None, None
if len(self._samples) == 0: # hack if there are too many MPI tasks
if len(self._local_samples) == 0: # hack if there are too many MPI tasks
tmp = self._hamiltonian(self._lin)
v = 0. * tmp.val.val
g = 0. * tmp.gradient
else:
for s in self._samples:
for s in self._local_samples:
tmp = self._hamiltonian(self._lin+s)
if self._mirror_samples:
tmp = tmp + self._hamiltonian(self._lin-s)
......@@ -213,7 +213,7 @@ class MetricGaussianKL(Energy):
return MetricGaussianKL(
position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, self._mirror_samples, comm=self._comm,
_samples=self._samples, lh_sampling_dtype=self._sampdt)
_local_samples=self._local_samples, lh_sampling_dtype=self._sampdt)
@property
def value(self):
......@@ -226,15 +226,15 @@ class MetricGaussianKL(Energy):
def _get_metric(self):
lin = self._lin.with_want_metric()
if self._metric is None:
if len(self._samples) == 0: # hack if there are too many MPI tasks
if len(self._local_samples) == 0: # hack if there are too many MPI tasks
self._metric = self._hamiltonian(lin).metric.scale(0.)
else:
mymap = map(lambda v: self._hamiltonian(lin+v).metric,
self._samples)
self._local_samples)
unscaled_metric = utilities.my_sum(mymap)
if self._mirror_samples:
mymap = map(lambda v: self._hamiltonian(lin-v).metric,
self._samples)
self._local_samples)
unscaled_metric = unscaled_metric + utilities.my_sum(mymap)
self._metric = unscaled_metric.scale(1./self._n_eff_samples)
......@@ -248,14 +248,22 @@ class MetricGaussianKL(Energy):
@property
def samples(self):
if self._comm is not None:
res = self._comm.allgather(self._samples)
res = tuple(item for sublist in res for item in sublist)
if self._comm is None:
for s in self._local_samples:
yield s
if self._mirror_samples:
yield -s
else:
res = self._samples
if self._mirror_samples:
res = res + tuple(-item for item in res)
return res
ntask = self._comm.Get_size()
rank = self._comm.Get_rank()
rank_lo_hi = [_shareRange(self._n_samples, ntask, i) for i in range(ntask)]
for itask, (l, h) in enumerate(rank_lo_hi):
for i in range(l, h):
data = self._local_samples[i-self._lo] if rank == itask else None
s = self._comm.bcast(data, root=itask)
yield s
if self._mirror_samples:
yield -s
def _metric_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse:
......@@ -263,7 +271,7 @@ class MetricGaussianKL(Energy):
lin = self._lin.with_want_metric()
samp = full(self._hamiltonian.domain, 0.)
sseq = random.spawn_sseq(self._n_samples)
for i, v in enumerate(self._samples):
for i, v in enumerate(self._local_samples):
random.push_sseq(sseq[self._lo+i])
samp = samp + self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False, dtype=dtype)
if self._mirror_samples:
......
......@@ -45,13 +45,13 @@ def test_kl(constants, point_estimates, mirror_samples):
point_estimates=point_estimates,
mirror_samples=mirror_samples,
napprox=0)
samp_full = kl.samples
locsamp = kl._local_samples
klpure = ift.MetricGaussianKL(mean0,
h,
len(samp_full),
mirror_samples=False,
nsamps,
mirror_samples=mirror_samples,
napprox=0,
_samples=samp_full)
_local_samples=locsamp)
# Test value
assert_allclose(kl.value, klpure.value)
......@@ -66,7 +66,7 @@ def test_kl(constants, point_estimates, mirror_samples):
# Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(kl.samples) == expected_nsamps)
assert_(len(tuple(kl.samples)) == expected_nsamps)
# Test point_estimates (after drawing samples)
for kk in point_estimates:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment