Commit e635e62c authored by Jakob Knollmüller's avatar Jakob Knollmüller
Browse files

KL sugar

parent 15d11511
Pipeline #94912 failed with stages
in 16 minutes and 23 seconds
......@@ -24,7 +24,7 @@ from ..multi_field import MultiField
from ..operators.endomorphic_operator import EndomorphicOperator
from ..operators.energy_operators import StandardHamiltonian
from ..operators.multifield_flattener import MultifieldFlattener
from ..probing import approximation2endo
from ..probing import approximation2endo, StatCalculator
from ..sugar import makeOp, full, from_random
from .energy import Energy
......@@ -262,6 +262,24 @@ class MetricGaussianKL(Energy):
yield s
if self._mirror_samples:
yield -s
@property
def distribution_samples(self):
dist_samps = []
for sample in self.samples:
dist_samps.append(self.position + sample)
return dist_samps
def estimate_quantity(self, op, averaged=False): # Maybe have StatCalculator also store
result = [] # samples and always return ?
for sample in self.distribution_samples:
result.append(op.force(sample))
if averaged:
sc = StatCalculator()
for samp in result:
sc.add(samp)
return sc
return result
def _metric_sample(self, from_inverse=False):
if from_inverse:
......@@ -434,3 +452,21 @@ class ParametricGaussianKL(Energy):
yield s
if self._mirror_samples:
yield -s
@property
def distribution_samples(self):
dist_samps = []
for sample in self.samples:
dist_samps.append(self._variational_model.generator(self.position + sample))
return dist_samps
def estimate_quantity(self, op, averaged=False):
result = []
for sample in self.distribution_samples:
result.append(op.force(sample))
if averaged:
sc = StatCalculator()
for samp in result:
sc.add(samp)
return sc
return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment