Commit 28920e82 authored by Philipp Frank's avatar Philipp Frank
Browse files

restructure mfvi

parent f30d1547
Pipeline #102752 passed with stages
in 13 minutes and 42 seconds
......@@ -75,8 +75,8 @@ if __name__ == "__main__":
position_fc = ift.from_random(H.domain)*0.1
position_mf = ift.from_random(H.domain)*0.
fc = ift.FullCovariance(position_fc, H, 3, True, initial_sig=0.01)
mf = ift.MeanField(position_mf, H, 3, True, initial_sig=0.0001)
fc = ift.FullCovarianceVI(position_fc, H, 3, True, initial_sig=0.01)
mf = ift.MeanFieldVI(position_mf, H, 3, True, initial_sig=0.0001)
minimizer_fc = ift.ADVIOptimizer(10)
minimizer_mf = ift.ADVIOptimizer(10)
......@@ -94,12 +94,12 @@ if __name__ == "__main__":
label="Full covariance",
)
plt.plot(
sky(mf.position).val, "r-", label="Mean field"
sky(mf.mean).val, "r-", label="Mean field"
)
#for samp in KL_fc.samples:
# plt.plot(
# sky(fullcov_model.generator(KL_fc.position + samp)).val, "b-", alpha=0.3
# )
for i in range(5):
plt.plot(
sky(mf.draw_sample()).val, "b-", alpha=0.3
)
#for samp in KL_mf.samples:
# plt.plot(
# sky(meanfield_model.generator(KL_mf.position + samp)).val,
......
......@@ -92,7 +92,7 @@ from .library.adjust_variances import (make_adjust_variances_hamiltonian,
from .library.nft import Gridder, FinuFFT
from .library.correlated_fields import CorrelatedFieldMaker
from .library.correlated_fields_simple import SimpleCorrelatedField
from .library.variational_models import MeanField, FullCovariance
from .library.variational_models import MeanFieldVI, FullCovarianceVI
from . import extra
......
......@@ -28,42 +28,57 @@ from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator
from ..operators.simple_linear_operators import FieldAdapter, PartialExtractor
from ..sugar import domain_union, full, makeField, is_fieldlike
from ..minimization.stochastic_minimizer import PartialSampledEnergy
from ..sugar import domain_union, full, makeField, from_random, is_fieldlike
from ..minimization.energy_adapter import StochasticEnergyAdapter
class MeanField:
def __init__(self, position, hamiltonian, n_samples, mirror_samples,
initial_sig=1, comm=None, nanisinf=False, names = ['mean', 'var']):
class MeanFieldVI:
def __init__(self, initial_position, hamiltonian, n_samples, mirror_samples,
initial_sig=1, comm=None, nanisinf=False):
"""Collect the operators required for Gaussian mean-field variational
inference.
"""
Flat = Multifield2Vector(position.domain)
std = FieldAdapter(Flat.target, names[1]).absolute()
Flat = Multifield2Vector(initial_position.domain)
self._std = FieldAdapter(Flat.target, 'std').absolute()
latent = FieldAdapter(Flat.target,'latent')
mean = FieldAdapter(Flat.target, names[0])
generator = Flat.adjoint(mean + std * latent)
entropy = GaussianEntropy(std.target) @ std
pos = {names[0]: Flat(position)}
self._mean = FieldAdapter(Flat.target, 'mean')
self._generator = Flat.adjoint(self._mean + self._std * latent)
self._entropy = GaussianEntropy(self._std.target) @ self._std
self._mean = Flat.adjoint @ self._mean
self._std = Flat.adjoint @ self._std
pos = {'mean': Flat(initial_position)}
if is_fieldlike(initial_sig):
pos[names[1]] = Flat(initial_sig)
pos['std'] = Flat(initial_sig)
else:
pos[names[1]] = full(Flat.target, initial_sig)
pos['std'] = full(Flat.target, initial_sig)
pos = MultiField.from_dict(pos)
op = hamiltonian(generator) + entropy
self._names = names
self._KL = PartialSampledEnergy.make(pos, op, ['latent',], n_samples, mirror_samples, nanisinf=nanisinf, comm=comm)
self._Flat = Flat
op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
mirror_samples, nanisinf=nanisinf, comm=comm)
self._samdom = latent.domain
@property
def position(self):
return self._Flat.adjoint(self._KL.position[self._names[0]])
def mean(self):
return self._mean.force(self._KL.position)
@property
def std(self):
return self._std.force(self._KL.position)
@property
def entropy(self):
return self._entropy.force(self._KL.position)
def draw_sample(self):
_, op = self._generator.simplify_for_constant_input(
from_random(self._samdom))
return op(self._KL.position)
def minimize(self, minimizer):
self._KL, _ = minimizer(self._KL)
class FullCovariance:
class FullCovarianceVI:
def __init__(self, position, hamiltonian, n_samples, mirror_samples,
initial_sig=1, comm=None, nanisinf=False, names = ['mean', 'cov']):
"""Collect the operators required for Gaussian full-covariance variational
......@@ -99,7 +114,7 @@ class FullCovariance:
pos = MultiField.from_dict({names[0]:Flat(position),names[1]:makeField(generator.domain[names[1]], diag_tri)})
op = hamiltonian(generator) + entropy
self._names = names
self._KL = PartialSampledEnergy.make(pos, op, ['latent',], n_samples, mirror_samples, nanisinf=nanisinf, comm=comm)
self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples, mirror_samples, nanisinf=nanisinf, comm=comm)
self._Flat = Flat
@property
......
......@@ -23,7 +23,6 @@ from ..minimization.energy import Energy
from ..utilities import myassert, allreduce_sum
from ..multi_domain import MultiDomain
from ..sugar import from_random
from .kl_energies import _SelfAdjointOperatorWrapper, _get_lo_hi
class EnergyAdapter(Energy):
"""Helper class which provides the traditional Nifty Energy interface to
......@@ -91,7 +90,7 @@ class EnergyAdapter(Energy):
class StochasticEnergyAdapter(Energy):
def __init__(self, position, op, keys, local_ops, n_samples, comm, nanisinf,
def __init__(self, position, bigop, keys, local_ops, n_samples, comm, nanisinf,
_callingfrommake=False):
if not _callingfrommake:
raise NotImplementedError
......@@ -112,7 +111,7 @@ class StochasticEnergyAdapter(Energy):
self._val = np.inf
self._grad = allreduce_sum(g, self._comm)/self._n_samples
self._op = op
self._op = bigop
self._keys = keys
@property
......@@ -136,6 +135,7 @@ class StochasticEnergyAdapter(Energy):
@property
def metric(self):
from .kl_energies import _SelfAdjointOperatorWrapper
return _SelfAdjointOperatorWrapper(self.position.domain,
self.apply_metric)
......@@ -158,6 +158,7 @@ class StochasticEnergyAdapter(Energy):
samdom = MultiDomain.make(samdom)
local_ops = []
sseq = random.spawn_sseq(n_samples)
from .kl_energies import _get_lo_hi
for i in range(*_get_lo_hi(comm, n_samples)):
with random.Context(sseq[i]):
rnd = from_random(samdom)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment