diff --git a/nifty6/library/correlated_fields.py b/nifty6/library/correlated_fields.py index 5571f0819957605f1efa7a1cd27cc83bafed4300..d02a80615a508459897d1356d224223e1acd4105 100644 --- a/nifty6/library/correlated_fields.py +++ b/nifty6/library/correlated_fields.py @@ -101,13 +101,6 @@ def _total_fluctuation_realized(samples): return np.sqrt((res/len(samples)).mean()) -def _stats(op, samples): - sc = StatCalculator() - for s in samples: - sc.add(op(s.extract(op.domain))) - return sc.mean.to_global_data(), sc.var.sqrt().to_global_data() - - class _LognormalMomentMatching(Operator): def __init__(self, mean, sig, key, N_copies): key = str(key) @@ -478,19 +471,17 @@ class CorrelatedFieldMaker: else: offset = float(offset) op = Adder(full(op.target, offset)) @ op - - if prior_info > 0: - from ..sugar import from_random - samps = [ - from_random('normal', op.domain) for _ in range(prior_info) - ] - self.statistics_summary(samps) + self.statistics_summary(prior_info) return op - def statistics_summary(self, samples): + def statistics_summary(self, prior_info): + from ..sugar import from_random + + if prior_info == 0: + return + lst = [('Offset amplitude', self.amplitude_total_offset), ('Total fluctuation amplitude', self.total_fluctuation)] - namps = len(self._a) if namps > 1: for ii in range(namps): @@ -500,7 +491,11 @@ class CorrelatedFieldMaker: self.average_fluctuation(ii))) for kk, op in lst: - mean, stddev = _stats(op, samples) + sc = StatCalculator() + for _ in range(prior_info): + sc.add(op(from_random('normal', op.domain))) + mean = sc.mean.to_global_data() + stddev = sc.var.sqrt().to_global_data() for m, s in zip(mean.flatten(), stddev.flatten()): print('{}: {:.02E} ± {:.02E}'.format(kk, m, s))