diff --git a/demos/meanfield_demo.py b/demos/meanfield_demo.py index bc9473d73807de986e164006fa29ea80e2c9123f..25f85a72cef0ddf4407c54215474016ba9d37b6b 100644 --- a/demos/meanfield_demo.py +++ b/demos/meanfield_demo.py @@ -58,24 +58,10 @@ def main(): if __name__ == '__main__': - # Choose space on which the signal field is defined - if len(sys.argv) == 2: - mode = int(sys.argv[1]) - else: - mode = 1 - - if mode == 0: - # One-dimensional regular grid with uniform exposure of 10 - position_space = ift.RGSpace(1024) - exposure = ift.Field.full(position_space, 10.) - elif mode == 1: - # Two-dimensional regular grid with inhomogeneous exposure - position_space = ift.RGSpace([512, 512]) - exposure = exposure_2d(position_space) - else: - # Sphere with uniform exposure of 100 - position_space = ift.HPSpace(128) - exposure = ift.Field.full(position_space, 100.) + + # Two-dimensional regular grid with inhomogeneous exposure + position_space = ift.RGSpace([10, 10]) + exposure = exposure_2d(position_space) # Define harmonic space and harmonic transform harmonic_space = position_space.get_default_codomain() @@ -120,10 +106,11 @@ if __name__ == '__main__': H = ift.StandardHamiltonian(likelihood) initial_position = ift.from_random(domain, 'normal') - meanfield_model = ift.MeanfieldModel(H.domain) - initial_position = meanfield_model.get_initial_pos() + # meanfield_model = ift.MeanfieldModel(H.domain) + fullcov_model = ift.FullCovarianceModel(H.domain) + initial_position = fullcov_model.get_initial_pos() position = initial_position - KL = ift.ParametricGaussianKL.make(initial_position,H,meanfield_model,3,False) + KL = ift.ParametricGaussianKL.make(initial_position,H,fullcov_model,3,False) plt.figure('data') plt.imshow(sky(mock_position).val) plt.pause(0.001) @@ -133,5 +120,5 @@ if __name__ == '__main__': position = KL.position plt.figure('result') plt.cla() - plt.imshow(sky(meanfield_model.generator(KL.position)).val) + plt.imshow(sky(fullcov_model.generator(KL.position)).val) plt.pause(0.001) \ No newline at end of file diff --git a/src/library/variational_models.py b/src/library/variational_models.py index e6441fc099c7399b57a1f21b91182cbd678510cf..ac405aba94c45bef30994d7b4676d7f18484e3fc 100644 --- a/src/library/variational_models.py +++ b/src/library/variational_models.py @@ -55,7 +55,7 @@ class FullCovarianceModel(): co_part = PartialExtractor(matmul_setup_dom, co.domain) lat_part = PartialExtractor(matmul_setup_dom, lat.domain) matmul_setup = lat_part.adjoint @ lat.adjoint @ lat + co_part.adjoint @ co.adjoint @ cov - + breakpoint() MatMult = MultiLinearEinsum(matmul_setup.domain,'ij,ki->jk', key_order=('co','latent')) Resp = Respacer(MatMult.target, mean.target)