Commit 317ba8d0 authored by Jakob Knollmüller's avatar Jakob Knollmüller
Browse files

add fullcovariance

parent d6982524
Pipeline #94601 failed with stages
in 16 minutes and 30 seconds
......@@ -90,7 +90,7 @@ from .library.adjust_variances import (make_adjust_variances_hamiltonian,
from .library.gridder import Gridder
from .library.correlated_fields import CorrelatedFieldMaker
from .library.correlated_fields_simple import SimpleCorrelatedField
from .library.variational_models import MeanfieldModel
from .library.variational_models import MeanfieldModel, FullCovarianceModel
from . import extra
......
import numpy as np
from ..operators.multifield_flattener import MultifieldFlattener
from ..operators.simple_linear_operators import FieldAdapter
from ..operators.simple_linear_operators import FieldAdapter, PartialExtractor
from ..operators.energy_operators import EnergyOperator
from ..operators.sandwich_operator import SandwichOperator
from ..sugar import full, from_random
from ..operators.linear_operator import LinearOperator
from ..operators.einsum import MultiLinearEinsum
from ..sugar import full, from_random, makeField, domain_union
from ..linearization import Linearization
from ..field import Field
from ..multi_field import MultiField
from ..domain_tuple import DomainTuple
from ..domains.unstructured_domain import UnstructuredDomain
class MeanfieldModel():
def __init__(self, domain):
......@@ -30,6 +34,48 @@ class MeanfieldModel():
initial_pos['mean'] = self.Flat(initial_mean)
return MultiField.from_dict(initial_pos)
class FullCovarianceModel():
def __init__(self, domain):
self.domain = domain
self.Flat = MultifieldFlattener(self.domain)
one_space = UnstructuredDomain(1)
self.latent_domain = self.Flat.target[0]
N_tri = self.latent_domain.shape[0]*(self.latent_domain.shape[0]+1)//2
triangular_space = DomainTuple.make(UnstructuredDomain(N_tri))
tri = FieldAdapter(triangular_space, 'cov')
mat_space = DomainTuple.make((self.latent_domain,self.latent_domain))
lat_mat_space = DomainTuple.make((one_space,self.latent_domain))
lat = FieldAdapter(lat_mat_space,'latent')
LT = LowerTriangularProjector(triangular_space,mat_space)
mean = FieldAdapter(self.latent_domain,'mean')
cov = LT @ tri
co = FieldAdapter(cov.target, 'co')
matmul_setup_dom = domain_union((co.domain,lat.domain))
co_part = PartialExtractor(matmul_setup_dom, co.domain)
lat_part = PartialExtractor(matmul_setup_dom, lat.domain)
matmul_setup = lat_part.adjoint @ lat.adjoint @ lat + co_part.adjoint @ co.adjoint @ cov
MatMult = MultiLinearEinsum(matmul_setup.domain,'ij,ki->jk', key_order=('co','latent'))
Resp = Respacer(MatMult.target, mean.target)
self.generator = self.Flat.adjoint @ (mean + Resp @ MatMult @ matmul_setup)
Diag = DiagonalSelector(cov.target, self.Flat.target)
diag_cov = Diag(cov).absolute()
self.entropy = GaussianEntropy(diag_cov.target) @ diag_cov
def get_initial_pos(self, initial_mean = None):
initial_pos = from_random(self.generator.domain).to_dict()
initial_pos['latent'] = full(self.latent_domain.domain['latent'], 0.)
diag_tri = np.diag(np.ones(self.latent_domain.shape[0]))[np.tril_indices(self.latent_domain.shape[0])]
initial_pos['cov'] = makeField(self.generator.domain['cov'], diag_tri)
if initial_mean is None:
initial_mean = 0.1*from_random(self.generator.target)
initial_pos['mean'] = self.Flat(initial_mean)
return MultiField.from_dict(initial_pos)
class GaussianEntropy(EnergyOperator):
def __init__(self, domain):
......@@ -42,4 +88,47 @@ class GaussianEntropy(EnergyOperator):
return Field.scalar(res)
if not x.want_metric:
return res
return res.add_metric(SandwichOperator.make(res.jac)) #FIXME not sure about metric
\ No newline at end of file
return res.add_metric(SandwichOperator.make(res.jac)) #FIXME not sure about metric
class LowerTriangularProjector(LinearOperator):
def __init__(self, domain, target):
self._domain = domain
self._target = target
self._indices=np.tril_indices(target.shape[1])
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_mode(mode)
if mode == self.TIMES:
mat = np.zeros(self._target.shape[1:])
mat[self._indices] = x.val
return makeField(self._target,mat.reshape((1,)+mat.shape))
return makeField(self._domain, x.val[0][self._indices].reshape(self._domain.shape))
class DiagonalSelector(LinearOperator):
def __init__(self, domain, target):
self._domain = domain
self._target = target
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
self._check_mode(mode)
if mode == self.TIMES:
result = np.diag(x.val[0])
return makeField(self._target,result)
return makeField(self._domain,np.diag(x.val).reshape(self._domain.shape))
class Respacer(LinearOperator):
def __init__(self, domain, target):
self._domain = domain
self._target = target
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self,x,mode):
self._check_mode(mode)
if mode == self.TIMES:
return makeField(self._target,x.val.reshape(self._target.shape))
return makeField(self._domain,x.val.reshape(self._domain.shape))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment