Commit 9a7ac57b authored by Martin Reinecke's avatar Martin Reinecke
Browse files

make more functions private

parent 34ed0cd4
Pipeline #71238 passed with stages
in 17 minutes and 7 seconds
......@@ -36,7 +36,7 @@ def _shareRange(nwork, nshares, myshare):
return lo, hi
def np_allreduce_sum(comm, arr):
def _np_allreduce_sum(comm, arr):
if comm is None:
return arr
from mpi4py import MPI
......@@ -46,18 +46,18 @@ def np_allreduce_sum(comm, arr):
return res
def allreduce_sum_field(comm, fld):
def _allreduce_sum_field(comm, fld):
if comm is None:
return fld
if isinstance(fld, Field):
return Field(fld.domain, np_allreduce_sum(fld.val))
return Field(fld.domain, _np_allreduce_sum(fld.val))
res = tuple(
Field(f.domain, np_allreduce_sum(comm, f.val))
Field(f.domain, _np_allreduce_sum(comm, f.val))
for f in fld.values())
return MultiField(fld.domain, res)
class KLMetric(EndomorphicOperator):
class _KLMetric(EndomorphicOperator):
def __init__(self, KL):
self._KL = KL
self._capability = self.TIMES | self.ADJOINT_TIMES
......@@ -68,7 +68,7 @@ class KLMetric(EndomorphicOperator):
return self._KL.apply_metric(x)
def draw_sample(self, from_inverse=False, dtype=np.float64):
return self._KL.metric_sample(from_inverse, dtype)
return self._KL._metric_sample(from_inverse, dtype)
class MetricGaussianKL(Energy):
......@@ -207,8 +207,8 @@ class MetricGaussianKL(Energy):
v += tmp.val.val
g = g + tmp.gradient
self._val = np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
self._grad = allreduce_sum_field(self._comm, g) / self._n_eff_samples
self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples
self._metric = None
self._sampdt = lh_sampling_dtype
......@@ -239,11 +239,11 @@ class MetricGaussianKL(Energy):
def apply_metric(self, x):
return allreduce_sum_field(self._comm, self._metric(x))
return _allreduce_sum_field(self._comm, self._metric(x))
def metric(self):
return KLMetric(self)
return _KLMetric(self)
def samples(self):
......@@ -256,7 +256,7 @@ class MetricGaussianKL(Energy):
res = res + tuple(-item for item in res)
return res
def unscaled_metric_sample(self, from_inverse=False, dtype=np.float64):
def _unscaled_metric_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse:
raise NotImplementedError()
lin = self._lin.with_want_metric()
......@@ -268,7 +268,7 @@ class MetricGaussianKL(Energy):
if self._mirror_samples:
samp = samp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False, dtype=dtype)
return allreduce_sum_field(self._comm, samp)
return _allreduce_sum_field(self._comm, samp)
def metric_sample(self, from_inverse=False, dtype=np.float64):
return self.unscaled_metric_sample(from_inverse, dtype)/self._n_eff_samples
def _metric_sample(self, from_inverse=False, dtype=np.float64):
return self._unscaled_metric_sample(from_inverse, dtype)/self._n_eff_samples
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment