Skip to content
Snippets Groups Projects
Commit b0f63234 authored by Philipp Arras's avatar Philipp Arras
Browse files

Merge remote-tracking branch 'origin/master' into devel

parents 89689d1d 846e8899
No related branches found
No related tags found
1 merge request!27Devel
......@@ -18,9 +18,7 @@ def simple_minimize(
class Minimization:
def __init__(
self, operator, position, n_samples, constants=[], point_estimates=[], comm=None
):
def __init__(self, operator, position, n_samples, constants=[], point_estimates=[], comm=None):
n_samples = int(n_samples)
self._position = position
position = position.extract(operator.domain)
......@@ -33,13 +31,14 @@ class Minimization:
"mean": position,
"hamiltonian": operator,
"n_samples": n_samples,
"minimizer_sampling": None,
"constants": constants,
"point_estimates": point_estimates,
"mirror_samples": True,
"comm": comm,
"nanisinf": True,
}
self._e = ift.MetricGaussianKL(**dct)
self._e = ift.SampledKL(**dct)
self._n, self._m = dct["n_samples"], dct["mirror_samples"]
def minimize(self, minimizer):
......
......@@ -24,7 +24,7 @@ def getop(comm, typ):
ddom = ift.UnstructuredDomain(d[0].shape)
ops = [
ift.GaussianEnergy(
ift.makeField(ddom, d[ii]), ift.makeOp(ift.makeField(ddom, invcov[ii]))
ift.makeField(ddom, d[ii]), ift.makeOp(ift.makeField(ddom, invcov[ii]), sampling_dtype=d[ii].dtype)
)
@ ift.DomainTupleFieldInserter(skydom, 0, (ii,)).adjoint
for ii in range(nwork)
......@@ -39,7 +39,7 @@ def getop(comm, typ):
for ii in local_indices:
ddom = ift.UnstructuredDomain(d[ii].shape)
dd = ift.makeField(ddom, d[ii])
iicc = ift.makeOp(ift.makeField(ddom, invcov[ii]))
iicc = ift.makeOp(ift.makeField(ddom, invcov[ii]), sampling_dtype=d[ii].dtype)
ee = ift.GaussianEnergy(dd, iicc)
if typ == 0:
ee = ee @ ift.DomainTupleFieldInserter(skydom, 0, (ii,)).adjoint
......@@ -117,6 +117,7 @@ def test_mpi_adder():
lin = ift.Linearization.make_var(ift.from_random(dom), True)
samps_lh, samps_ham = [], []
for ii, (llhh, hh) in enumerate(zip(lhs_for_sampling, hams_for_sampling)):
print(ii)
with ift.random.Context(42):
samps_lh.append(llhh(lin).metric.draw_sample())
with ift.random.Context(42):
......@@ -127,7 +128,5 @@ def test_mpi_adder():
mini_results = []
for ham in hams_for_sampling:
with ift.random.Context(42):
mini_results.append(
mini(ift.MetricGaussianKL(pos, ham, 3, True))[0].position
)
mini_results.append(mini(ift.SampledKLEnergy(pos, ham, 3, None))[0].position)
allclose(mini_results)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment