diff --git a/src/library/variational_models.py b/src/library/variational_models.py index ac405aba94c45bef30994d7b4676d7f18484e3fc..7931f49a8c9aa471c344e5191899c4cd142c83a0 100644 --- a/src/library/variational_models.py +++ b/src/library/variational_models.py @@ -55,8 +55,7 @@ class FullCovarianceModel(): co_part = PartialExtractor(matmul_setup_dom, co.domain) lat_part = PartialExtractor(matmul_setup_dom, lat.domain) matmul_setup = lat_part.adjoint @ lat.adjoint @ lat + co_part.adjoint @ co.adjoint @ cov - breakpoint() - MatMult = MultiLinearEinsum(matmul_setup.domain,'ij,ki->jk', key_order=('co','latent')) + MatMult = MultiLinearEinsum(matmul_setup.target,'ij,ki->jk', key_order=('co','latent')) Resp = Respacer(MatMult.target, mean.target) self.generator = self.Flat.adjoint @ (mean + Resp @ MatMult @ matmul_setup)