Commit 193a276f authored by Martin Reinecke
Browse files

Merge branch 'KL_sample_generator' into 'NIFTy_6'

Use a generator for MetricGaussianKL.samples

See merge request !432
parents c01d10cc 8e377ee6
Pipeline #71475 passed with stages
in 27 minutes and 46 seconds
...@@ -121,7 +121,7 @@ class MetricGaussianKL(Energy): ...@@ -121,7 +121,7 @@ class MetricGaussianKL(Energy):
the presence of this parameter is that metric of the likelihood energy the presence of this parameter is that metric of the likelihood energy
is just an `Operator` which does not know anything about the dtype of is just an `Operator` which does not know anything about the dtype of
the fields on which it acts. Default is float64. the fields on which it acts. Default is float64.
_samples : None _local_samples : None
Only a parameter for internal uses. Typically not to be set by users. Only a parameter for internal uses. Typically not to be set by users.
Note Note
...@@ -138,7 +138,7 @@ class MetricGaussianKL(Energy): ...@@ -138,7 +138,7 @@ class MetricGaussianKL(Energy):
def __init__(self, mean, hamiltonian, n_samples, constants=[], def __init__(self, mean, hamiltonian, n_samples, constants=[],
point_estimates=[], mirror_samples=False, point_estimates=[], mirror_samples=False,
napprox=0, comm=None, _samples=None, napprox=0, comm=None, _local_samples=None,
lh_sampling_dtype=np.float64): lh_sampling_dtype=np.float64):
super(MetricGaussianKL, self).__init__(mean) super(MetricGaussianKL, self).__init__(mean)
...@@ -170,31 +170,31 @@ class MetricGaussianKL(Energy): ...@@ -170,31 +170,31 @@ class MetricGaussianKL(Energy):
if self._mirror_samples: if self._mirror_samples:
self._n_eff_samples *= 2 self._n_eff_samples *= 2
if _samples is None: if _local_samples is None:
met = hamiltonian(Linearization.make_partial_var( met = hamiltonian(Linearization.make_partial_var(
mean, self._point_estimates, True)).metric mean, self._point_estimates, True)).metric
if napprox >= 1: if napprox >= 1:
met._approximation = makeOp(approximation2endo(met, napprox)) met._approximation = makeOp(approximation2endo(met, napprox))
_samples = [] _local_samples = []
sseq = random.spawn_sseq(self._n_samples) sseq = random.spawn_sseq(self._n_samples)
for i in range(self._lo, self._hi): for i in range(self._lo, self._hi):
random.push_sseq(sseq[i]) random.push_sseq(sseq[i])
_samples.append(met.draw_sample(from_inverse=True, _local_samples.append(met.draw_sample(from_inverse=True,
dtype=lh_sampling_dtype)) dtype=lh_sampling_dtype))
random.pop_sseq() random.pop_sseq()
_samples = tuple(_samples) _local_samples = tuple(_local_samples)
else: else:
if len(_samples) != self._hi-self._lo: if len(_local_samples) != self._hi-self._lo:
raise ValueError("# of samples mismatch") raise ValueError("# of samples mismatch")
self._samples = _samples self._local_samples = _local_samples
self._lin = Linearization.make_partial_var(mean, self._constants) self._lin = Linearization.make_partial_var(mean, self._constants)
v, g = None, None v, g = None, None
if len(self._samples) == 0: # hack if there are too many MPI tasks if len(self._local_samples) == 0: # hack if there are too many MPI tasks
tmp = self._hamiltonian(self._lin) tmp = self._hamiltonian(self._lin)
v = 0. * tmp.val.val v = 0. * tmp.val.val
g = 0. * tmp.gradient g = 0. * tmp.gradient
else: else:
for s in self._samples: for s in self._local_samples:
tmp = self._hamiltonian(self._lin+s) tmp = self._hamiltonian(self._lin+s)
if self._mirror_samples: if self._mirror_samples:
tmp = tmp + self._hamiltonian(self._lin-s) tmp = tmp + self._hamiltonian(self._lin-s)
...@@ -213,7 +213,7 @@ class MetricGaussianKL(Energy): ...@@ -213,7 +213,7 @@ class MetricGaussianKL(Energy):
return MetricGaussianKL( return MetricGaussianKL(
position, self._hamiltonian, self._n_samples, self._constants, position, self._hamiltonian, self._n_samples, self._constants,
self._point_estimates, self._mirror_samples, comm=self._comm, self._point_estimates, self._mirror_samples, comm=self._comm,
_samples=self._samples, lh_sampling_dtype=self._sampdt) _local_samples=self._local_samples, lh_sampling_dtype=self._sampdt)
@property @property
def value(self): def value(self):
...@@ -226,15 +226,15 @@ class MetricGaussianKL(Energy): ...@@ -226,15 +226,15 @@ class MetricGaussianKL(Energy):
def _get_metric(self): def _get_metric(self):
lin = self._lin.with_want_metric() lin = self._lin.with_want_metric()
if self._metric is None: if self._metric is None:
if len(self._samples) == 0: # hack if there are too many MPI tasks if len(self._local_samples) == 0: # hack if there are too many MPI tasks
self._metric = self._hamiltonian(lin).metric.scale(0.) self._metric = self._hamiltonian(lin).metric.scale(0.)
else: else:
mymap = map(lambda v: self._hamiltonian(lin+v).metric, mymap = map(lambda v: self._hamiltonian(lin+v).metric,
self._samples) self._local_samples)
unscaled_metric = utilities.my_sum(mymap) unscaled_metric = utilities.my_sum(mymap)
if self._mirror_samples: if self._mirror_samples:
mymap = map(lambda v: self._hamiltonian(lin-v).metric, mymap = map(lambda v: self._hamiltonian(lin-v).metric,
self._samples) self._local_samples)
unscaled_metric = unscaled_metric + utilities.my_sum(mymap) unscaled_metric = unscaled_metric + utilities.my_sum(mymap)
self._metric = unscaled_metric.scale(1./self._n_eff_samples) self._metric = unscaled_metric.scale(1./self._n_eff_samples)
...@@ -248,14 +248,22 @@ class MetricGaussianKL(Energy): ...@@ -248,14 +248,22 @@ class MetricGaussianKL(Energy):
@property @property
def samples(self): def samples(self):
if self._comm is not None: if self._comm is None:
res = self._comm.allgather(self._samples) for s in self._local_samples:
res = tuple(item for sublist in res for item in sublist) yield s
if self._mirror_samples:
yield -s
else: else:
res = self._samples ntask = self._comm.Get_size()
if self._mirror_samples: rank = self._comm.Get_rank()
res = res + tuple(-item for item in res) rank_lo_hi = [_shareRange(self._n_samples, ntask, i) for i in range(ntask)]
return res for itask, (l, h) in enumerate(rank_lo_hi):
for i in range(l, h):
data = self._local_samples[i-self._lo] if rank == itask else None
s = self._comm.bcast(data, root=itask)
yield s
if self._mirror_samples:
yield -s
def _metric_sample(self, from_inverse=False, dtype=np.float64): def _metric_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse: if from_inverse:
...@@ -263,7 +271,7 @@ class MetricGaussianKL(Energy): ...@@ -263,7 +271,7 @@ class MetricGaussianKL(Energy):
lin = self._lin.with_want_metric() lin = self._lin.with_want_metric()
samp = full(self._hamiltonian.domain, 0.) samp = full(self._hamiltonian.domain, 0.)
sseq = random.spawn_sseq(self._n_samples) sseq = random.spawn_sseq(self._n_samples)
for i, v in enumerate(self._samples): for i, v in enumerate(self._local_samples):
random.push_sseq(sseq[self._lo+i]) random.push_sseq(sseq[self._lo+i])
samp = samp + self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False, dtype=dtype) samp = samp + self._hamiltonian(lin+v).metric.draw_sample(from_inverse=False, dtype=dtype)
if self._mirror_samples: if self._mirror_samples:
......
...@@ -45,13 +45,13 @@ def test_kl(constants, point_estimates, mirror_samples): ...@@ -45,13 +45,13 @@ def test_kl(constants, point_estimates, mirror_samples):
point_estimates=point_estimates, point_estimates=point_estimates,
mirror_samples=mirror_samples, mirror_samples=mirror_samples,
napprox=0) napprox=0)
samp_full = kl.samples locsamp = kl._local_samples
klpure = ift.MetricGaussianKL(mean0, klpure = ift.MetricGaussianKL(mean0,
h, h,
len(samp_full), nsamps,
mirror_samples=False, mirror_samples=mirror_samples,
napprox=0, napprox=0,
_samples=samp_full) _local_samples=locsamp)
# Test value # Test value
assert_allclose(kl.value, klpure.value) assert_allclose(kl.value, klpure.value)
...@@ -66,7 +66,7 @@ def test_kl(constants, point_estimates, mirror_samples): ...@@ -66,7 +66,7 @@ def test_kl(constants, point_estimates, mirror_samples):
# Test number of samples # Test number of samples
expected_nsamps = 2*nsamps if mirror_samples else nsamps expected_nsamps = 2*nsamps if mirror_samples else nsamps
assert_(len(kl.samples) == expected_nsamps) assert_(len(tuple(kl.samples)) == expected_nsamps)
# Test point_estimates (after drawing samples) # Test point_estimates (after drawing samples)
for kk in point_estimates: for kk in point_estimates:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment