Commit 91c95868 authored by Philipp Frank's avatar Philipp Frank
Browse files

restructure fcvi

parent 28920e82
......@@ -73,11 +73,11 @@ if __name__ == "__main__":
H = ift.StandardHamiltonian(likelihood)
position_fc = ift.from_random(H.domain)*0.1
position_mf = ift.from_random(H.domain)*0.
position_mf = ift.from_random(H.domain)*0.1
fc = ift.FullCovarianceVI(position_fc, H, 3, True, initial_sig=0.01)
mf = ift.MeanFieldVI(position_mf, H, 3, True, initial_sig=0.0001)
minimizer_fc = ift.ADVIOptimizer(10)
mf = ift.MeanFieldVI(position_mf, H, 3, True, initial_sig=0.01)
minimizer_fc = ift.ADVIOptimizer(20, eta = 0.1)
minimizer_mf = ift.ADVIOptimizer(10)
plt.pause(0.001)
......@@ -89,7 +89,7 @@ if __name__ == "__main__":
plt.figure("result")
plt.cla()
plt.plot(
sky(fc.position).val,
sky(fc.mean).val,
"b-",
label="Full covariance",
)
......@@ -98,14 +98,11 @@ if __name__ == "__main__":
)
for i in range(5):
plt.plot(
sky(mf.draw_sample()).val, "b-", alpha=0.3
sky(fc.draw_sample()).val, "b-", alpha=0.3
)
plt.plot(
sky(mf.draw_sample()).val, "r-", alpha=0.3
)
#for samp in KL_mf.samples:
# plt.plot(
# sky(meanfield_model.generator(KL_mf.position + samp)).val,
# "r-",
# alpha=0.3,
# )
plt.plot(data.val, "kx")
plt.plot(sky(mock_position).val, "k-", label="Ground truth")
plt.legend()
......
......@@ -27,8 +27,8 @@ from ..operators.energy_operators import EnergyOperator
from ..operators.linear_operator import LinearOperator
from ..operators.multifield2vector import Multifield2Vector
from ..operators.sandwich_operator import SandwichOperator
from ..operators.simple_linear_operators import FieldAdapter, PartialExtractor
from ..sugar import domain_union, full, makeField, from_random, is_fieldlike
from ..operators.simple_linear_operators import FieldAdapter
from ..sugar import full, makeField, from_random, is_fieldlike
from ..minimization.energy_adapter import StochasticEnergyAdapter
......@@ -80,46 +80,52 @@ class MeanFieldVI:
class FullCovarianceVI:
def __init__(self, position, hamiltonian, n_samples, mirror_samples,
initial_sig=1, comm=None, nanisinf=False, names = ['mean', 'cov']):
initial_sig=1, comm=None, nanisinf=False):
"""Collect the operators required for Gaussian full-covariance variational
inference.
"""
Flat = Multifield2Vector(position.domain)
one_space = UnstructuredDomain(1)
flat_domain = Flat.target[0]
N_tri = flat_domain.shape[0]*(flat_domain.shape[0]+1)//2
triangular_space = DomainTuple.make(UnstructuredDomain(N_tri))
tri = FieldAdapter(triangular_space, names[1])
tri = FieldAdapter(triangular_space, 'cov')
mat_space = DomainTuple.make((flat_domain,flat_domain))
lat_mat_space = DomainTuple.make((one_space,flat_domain))
lat = FieldAdapter(lat_mat_space,'latent')
lat = FieldAdapter(Flat.target,'latent')
LT = LowerTriangularProjector(triangular_space,mat_space)
mean = FieldAdapter(flat_domain,names[0])
mean = FieldAdapter(flat_domain,'mean')
cov = LT @ tri
co = FieldAdapter(cov.target, 'co')
matmul_setup_dom = domain_union((co.domain,lat.domain))
co_part = PartialExtractor(matmul_setup_dom, co.domain)
lat_part = PartialExtractor(matmul_setup_dom, lat.domain)
matmul_setup = lat_part.adjoint @ lat.adjoint @ lat + co_part.adjoint @ co.adjoint @ cov
MatMult = MultiLinearEinsum(matmul_setup.target,'ij,ki->jk', key_order=('co','latent'))
Resp = Respacer(MatMult.target, mean.target)
generator = Flat.adjoint @ (mean + Resp @ MatMult @ matmul_setup)
matmul_setup = lat.adjoint @ lat + cov.ducktape_left('co')
MatMult = MultiLinearEinsum(matmul_setup.target,'ij,j->i',
key_order=('co','latent'))
self._generator = Flat.adjoint @ (mean + MatMult @ matmul_setup)
Diag = DiagonalSelector(cov.target, Flat.target)
diag_cov = Diag(cov).absolute()
entropy = GaussianEntropy(diag_cov.target) @ diag_cov
diag_tri = np.diag(np.full(flat_domain.shape[0], initial_sig))[np.tril_indices(flat_domain.shape[0])]
pos = MultiField.from_dict({names[0]:Flat(position),names[1]:makeField(generator.domain[names[1]], diag_tri)})
op = hamiltonian(generator) + entropy
self._names = names
self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples, mirror_samples, nanisinf=nanisinf, comm=comm)
self._Flat = Flat
self._entropy = GaussianEntropy(diag_cov.target) @ diag_cov
diag_tri = np.diag(np.full(flat_domain.shape[0], initial_sig))
diag_tri = diag_tri[np.tril_indices(flat_domain.shape[0])]
pos = MultiField.from_dict(
{'mean':Flat(position),
'cov':makeField(triangular_space, diag_tri)})
op = hamiltonian(self._generator) + self._entropy
self._KL = StochasticEnergyAdapter.make(pos, op, ['latent',], n_samples,
mirror_samples, nanisinf=nanisinf, comm=comm)
self._mean = Flat.adjoint @ mean
self._samdom = lat.domain
@property
def mean(self):
return self._mean.force(self._KL.position)
@property
def position(self):
return self._Flat.adjoint(self._KL.position[self._names[0]])
def entropy(self):
return self._entropy.force(self._KL.position)
def draw_sample(self):
_, op = self._generator.simplify_for_constant_input(
from_random(self._samdom))
return op(self._KL.position)
def minimize(self, minimizer):
self._KL, _ = minimizer(self._KL)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment