Commit 31cd1173 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add MultiField support

parent e8618b90
......@@ -502,16 +502,23 @@ def calculate_position(operator, output):
raise TypeError
if isinstance(output, MultiField):
cov = 1e-3*max([vv.max() for vv in output.val.values()])**2
invcov = ScalingOperator(output.domain, cov).inverse
dtype = list(set([ff.dtype for ff in output.values()]))
if len(dtype) != 1:
raise ValueError('Only MultiFields with one dtype supported.')
dtype = dtype[0]
else:
cov = 1e-3*output.val.max()**2
dtype = output.dtype
invcov = ScalingOperator(output.domain, cov).inverse
d = output + invcov.draw_sample_with_dtype(dtype=output.dtype, from_inverse=True)
d = output + invcov.draw_sample(dtype, from_inverse=True)
lh = GaussianEnergy(d, invcov) @ operator
H = StandardHamiltonian(
lh, ic_samp=GradientNormController(iteration_limit=200))
pos = 0.1 * from_random(operator.domain, 'normal')
minimizer = NewtonCG(GradientNormController(iteration_limit=10))
pos = 0.1*from_random('normal', operator.domain)
minimizer = NewtonCG(GradientNormController(iteration_limit=10, name='findpos'))
for ii in range(3):
logger.info(f'Start iteration {ii+1}/3')
kl = MetricGaussianKL(pos, H, 3, mirror_samples=True)
kl, _ = minimizer(kl)
pos = kl.position
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment