Commit 9a7ac57b authored by Martin Reinecke's avatar Martin Reinecke
Browse files

make more functions private

parent 34ed0cd4
Pipeline #71238 passed with stages
in 17 minutes and 7 seconds
...@@ -36,7 +36,7 @@ def _shareRange(nwork, nshares, myshare): ...@@ -36,7 +36,7 @@ def _shareRange(nwork, nshares, myshare):
return lo, hi return lo, hi
def np_allreduce_sum(comm, arr): def _np_allreduce_sum(comm, arr):
if comm is None: if comm is None:
return arr return arr
from mpi4py import MPI from mpi4py import MPI
...@@ -46,18 +46,18 @@ def np_allreduce_sum(comm, arr): ...@@ -46,18 +46,18 @@ def np_allreduce_sum(comm, arr):
return res return res
def allreduce_sum_field(comm, fld): def _allreduce_sum_field(comm, fld):
if comm is None: if comm is None:
return fld return fld
if isinstance(fld, Field): if isinstance(fld, Field):
return Field(fld.domain, np_allreduce_sum(fld.val)) return Field(fld.domain, _np_allreduce_sum(fld.val))
res = tuple( res = tuple(
Field(f.domain, np_allreduce_sum(comm, f.val)) Field(f.domain, _np_allreduce_sum(comm, f.val))
for f in fld.values()) for f in fld.values())
return MultiField(fld.domain, res) return MultiField(fld.domain, res)
class KLMetric(EndomorphicOperator): class _KLMetric(EndomorphicOperator):
def __init__(self, KL): def __init__(self, KL):
self._KL = KL self._KL = KL
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
...@@ -68,7 +68,7 @@ class KLMetric(EndomorphicOperator): ...@@ -68,7 +68,7 @@ class KLMetric(EndomorphicOperator):
return self._KL.apply_metric(x) return self._KL.apply_metric(x)
def draw_sample(self, from_inverse=False, dtype=np.float64): def draw_sample(self, from_inverse=False, dtype=np.float64):
return self._KL.metric_sample(from_inverse, dtype) return self._KL._metric_sample(from_inverse, dtype)
class MetricGaussianKL(Energy): class MetricGaussianKL(Energy):
...@@ -207,8 +207,8 @@ class MetricGaussianKL(Energy): ...@@ -207,8 +207,8 @@ class MetricGaussianKL(Energy):
else: else:
v += tmp.val.val v += tmp.val.val
g = g + tmp.gradient g = g + tmp.gradient
self._val = np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples self._val = _np_allreduce_sum(self._comm, v)[()] / self._n_eff_samples
self._grad = allreduce_sum_field(self._comm, g) / self._n_eff_samples self._grad = _allreduce_sum_field(self._comm, g) / self._n_eff_samples
self._metric = None self._metric = None
self._sampdt = lh_sampling_dtype self._sampdt = lh_sampling_dtype
...@@ -239,11 +239,11 @@ class MetricGaussianKL(Energy): ...@@ -239,11 +239,11 @@ class MetricGaussianKL(Energy):
def apply_metric(self, x): def apply_metric(self, x):
self._get_metric() self._get_metric()
return allreduce_sum_field(self._comm, self._metric(x)) return _allreduce_sum_field(self._comm, self._metric(x))
@property @property
def metric(self): def metric(self):
return KLMetric(self) return _KLMetric(self)
@property @property
def samples(self): def samples(self):
...@@ -256,7 +256,7 @@ class MetricGaussianKL(Energy): ...@@ -256,7 +256,7 @@ class MetricGaussianKL(Energy):
res = res + tuple(-item for item in res) res = res + tuple(-item for item in res)
return res return res
def unscaled_metric_sample(self, from_inverse=False, dtype=np.float64): def _unscaled_metric_sample(self, from_inverse=False, dtype=np.float64):
if from_inverse: if from_inverse:
raise NotImplementedError() raise NotImplementedError()
lin = self._lin.with_want_metric() lin = self._lin.with_want_metric()
...@@ -268,7 +268,7 @@ class MetricGaussianKL(Energy): ...@@ -268,7 +268,7 @@ class MetricGaussianKL(Energy):
if self._mirror_samples: if self._mirror_samples:
samp = samp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False, dtype=dtype) samp = samp + self._hamiltonian(lin-v).metric.draw_sample(from_inverse=False, dtype=dtype)
random.pop_sseq() random.pop_sseq()
return allreduce_sum_field(self._comm, samp) return _allreduce_sum_field(self._comm, samp)
def metric_sample(self, from_inverse=False, dtype=np.float64): def _metric_sample(self, from_inverse=False, dtype=np.float64):
return self.unscaled_metric_sample(from_inverse, dtype)/self._n_eff_samples return self._unscaled_metric_sample(from_inverse, dtype)/self._n_eff_samples
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment