Commit 2ad2d650 authored by Philipp Arras's avatar Philipp Arras
Browse files

Iterate only once through KL samples

parent f40d7d51
Pipeline #88122 passed with stages
in 11 minutes and 22 seconds
......@@ -127,7 +127,7 @@ def main():
KL, convergence = minimizer(KL)
mean = KL.position
ift.extra.minisanity(data, lambda x: N.inverse, signal_response,
KL.position, list(KL.samples))
KL.position, KL.samples)
# Plot current reconstruction
plot = ift.Plot()
......
......@@ -431,7 +431,6 @@ def minisanity(data, metric_at_pos, modeldata_operator, mean, samples=None):
samples : list or tuple of Field or MultiField, optional
Residual samples around `mean`. Default: no samples.
# FIXME @mtr probably we need MPI support here, right?
Note
----
......@@ -444,7 +443,6 @@ def minisanity(data, metric_at_pos, modeldata_operator, mean, samples=None):
is_operator(modeldata_operator)
and is_fieldlike(data)
and is_fieldlike(mean)
and isinstance(samples, (list, tuple))
):
raise TypeError
keylen = 18
......@@ -452,9 +450,30 @@ def minisanity(data, metric_at_pos, modeldata_operator, mean, samples=None):
if isinstance(dom, MultiDomain):
keylen = max([max(map(len, dom.keys())), keylen])
keylen = min([keylen, 42])
normresi = metric_at_pos(mean).get_sqrt() @ Adder(data, neg=True) @ modeldata_operator
s0 = _comp_chisq(normresi, mean, samples, keylen)
s1 = _comp_chisq(ScalingOperator(mean.domain, 1), mean, samples, keylen)
op0 = metric_at_pos(mean).get_sqrt() @ Adder(data, neg=True) @ modeldata_operator
op1 = ScalingOperator(mean.domain, 1)
if not isinstance(op0.target, MultiDomain):
op0 = op0.ducktape_left("<None>")
if not isinstance(op1.target, MultiDomain):
op1 = op1.ducktape_left("<None>")
s = [full(mean.domain, 0.0)] if samples is None else samples
xop = op0, op1
xkeys = op0.target.keys(), op1.target.keys()
xredchisq, xscmean, xndof = 2*[None], 2*[None], 2*[None]
for aa in [0, 1]:
xredchisq[aa] = {kk: StatCalculator() for kk in xkeys[aa]}
xscmean[aa] = {kk: StatCalculator() for kk in xkeys[aa]}
xndof[aa] = {}
for ii, ss in enumerate(s):
for aa in [0, 1]:
rr = xop[aa].force(mean.unite(ss))
for kk in xkeys[aa]:
xredchisq[aa][kk].add(np.nansum(abs(rr[kk].val) ** 2) / rr[kk].size)
xscmean[aa][kk].add(np.nanmean(rr[kk].val))
xndof[aa][kk] = rr[kk].size - np.sum(np.isnan(rr[kk].val))
s0 = _tableentries(xredchisq[0], xscmean[0], xndof[0], keylen)
s1 = _tableentries(xredchisq[1], xscmean[1], xndof[1], keylen)
f = logger.info
n = 38 + keylen
......@@ -478,21 +497,7 @@ class _bcolors:
BOLD = "\033[1m"
def _comp_chisq(op, p, samples, keylen):
s = [full(p.domain, 0.0)] if samples is None else samples
mf = isinstance(op.target, MultiDomain)
if not mf:
op = op.ducktape_left("<None>")
keys = op.target.keys()
redchisq = {kk: StatCalculator() for kk in keys}
mean = {kk: StatCalculator() for kk in keys}
ndof = {}
for ii, ss in enumerate(s):
rr = op.force(p.unite(ss))
for kk in keys:
redchisq[kk].add(np.nansum(abs(rr[kk].val) ** 2) / rr[kk].size)
mean[kk].add(np.nanmean(rr[kk].val))
ndof[kk] = rr[kk].size - np.sum(np.isnan(rr[kk].val))
def _tableentries(redchisq, scmean, ndof, keylen):
out = ""
for kk in redchisq.keys():
if len(kk) > keylen:
......@@ -511,9 +516,9 @@ def _comp_chisq(op, p, samples, keylen):
else:
out += f"{foo:>11}"
foo = f"{mean[kk].mean:.1f}"
foo = f"{scmean[kk].mean:.1f}"
try:
foo += f" ± {np.sqrt(mean[kk].var):.1f}"
foo += f" ± {np.sqrt(scmean[kk].var):.1f}"
except RuntimeError:
pass
out += f"{foo:>14}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment