diff --git a/nifty6/library/correlated_fields.py b/nifty6/library/correlated_fields.py index f99ab8d51c7332083f502c7395fa23b3ea0d79de..94f5633a3c29b5a50afc35bddd99a1e3cef2dec7 100644 --- a/nifty6/library/correlated_fields.py +++ b/nifty6/library/correlated_fields.py @@ -21,15 +21,17 @@ import numpy as np from ..domain_tuple import DomainTuple from ..domains.power_space import PowerSpace from ..domains.unstructured_domain import UnstructuredDomain +from ..field import Field +from ..multi_field import MultiField from ..operators.adder import Adder from ..operators.contraction_operator import ContractionOperator +from ..operators.diagonal_operator import DiagonalOperator from ..operators.distributors import PowerDistributor from ..operators.endomorphic_operator import EndomorphicOperator from ..operators.harmonic_operators import HarmonicTransformOperator from ..operators.linear_operator import LinearOperator -from ..operators.diagonal_operator import DiagonalOperator from ..operators.operator import Operator -from ..operators.simple_linear_operators import VdotOperator, ducktape +from ..operators.simple_linear_operators import ducktape from ..probing import StatCalculator from ..sugar import from_global_data, full, makeDomain @@ -49,8 +51,11 @@ def _lognormal_moments(mean, sig, N=0): mean, sig = np.asfarray(mean), np.asfarray(sig) else: mean, sig = (_reshaper(param, N) for param in (mean, sig)) - assert np.all(mean > 0) - assert np.all(sig > 0) + if not np.all(mean > 0): + raise ValueError(f"mean must be greater 0; got {mean!r}") + if not np.all(sig > 0): + raise ValueError(f"sig must be greater 0; got {sig!r}") + logsig = np.sqrt(np.log((sig/mean)**2 + 1)) logmean = np.log(mean) - logsig**2/2 return logmean, logsig @@ -63,9 +68,8 @@ def _normal(mean, sig, key, N=0): else: domain = UnstructuredDomain(N) mean, sig = (_reshaper(param, N) for param in (mean, sig)) - return Adder(from_global_data(domain, mean)) @ ( - DiagonalOperator(from_global_data(domain, sig)) - @ ducktape(domain, None, key)) + return Adder(from_global_data(domain, mean)) @ (DiagonalOperator( + from_global_data(domain, sig)) @ ducktape(domain, None, key)) def _log_k_lengths(pspace): @@ -99,13 +103,6 @@ def _total_fluctuation_realized(samples): return np.sqrt((res/len(samples)).mean()) -def _stats(op, samples): - sc = StatCalculator() - for s in samples: - sc.add(op(s.extract(op.domain))) - return sc.mean.val, sc.var.sqrt().val - - class _LognormalMomentMatching(Operator): def __init__(self, mean, sig, key, N_copies): key = str(key) @@ -146,7 +143,7 @@ class _SlopeRemover(EndomorphicOperator): else: res = x.copy() res[self._last] -= (x*self._sc[self._extender]).sum( - axis=self._space, keepdims=True) + axis=self._space, keepdims=True) return from_global_data(self._tgt(mode), res) @@ -199,11 +196,14 @@ class _Normalization(Operator): hspace = list(self._domain) hspace[space] = hspace[space].harmonic_partner hspace = makeDomain(hspace) - pd = PowerDistributor(hspace, power_space=self._domain[space], space=space) + pd = PowerDistributor(hspace, + power_space=self._domain[space], + space=space) mode_multiplicity = pd.adjoint(full(pd.target, 1.)).val.copy() zero_mode = (slice(None),)*self._domain.axes[space][0] + (0,) mode_multiplicity[zero_mode] = 0 - self._mode_multiplicity = from_global_data(self._domain, mode_multiplicity) + self._mode_multiplicity = from_global_data(self._domain, + mode_multiplicity) self._specsum = _SpecialSum(self._domain, space) def apply(self, x): @@ -322,8 +322,8 @@ class _Amplitude(Operator): op = Distributor @ op sig_fluc = Distributor @ sig_fluc op = Adder(Distributor(vol0)) @ (sig_fluc*(azm_expander @ azm.one_over())*op) - self._fluc = (_Distributor(dofdex, fluctuations.target, distributed_tgt[0]) @ - fluctuations) + self._fluc = (_Distributor(dofdex, fluctuations.target, + distributed_tgt[0]) @ fluctuations) else: op = Adder(vol0) @ (sig_fluc*(azm_expander @ azm.one_over())*op) self._fluc = fluctuations @@ -339,7 +339,8 @@ class _Amplitude(Operator): class CorrelatedFieldMaker: def __init__(self, amplitude_offset, prefix, total_N): - assert isinstance(amplitude_offset, Operator) + if not isinstance(amplitude_offset, Operator): + raise TypeError("amplitude_offset needs to be an operator") self._a = [] self._position_spaces = [] @@ -353,8 +354,8 @@ class CorrelatedFieldMaker: dofdex=None): if dofdex is None: dofdex = np.full(total_N, 0) - else: - assert len(dofdex) == total_N + elif len(dofdex) != total_N: + raise ValueError("length of dofdex needs to match total_N") N = max(dofdex) + 1 if total_N > 0 else 0 zm = _LognormalMomentMatching(offset_amplitude_mean, offset_amplitude_stddev, @@ -386,8 +387,8 @@ class CorrelatedFieldMaker: if dofdex is None: dofdex = np.full(self._total_N, 0) - else: - assert len(dofdex) == self._total_N + elif len(dofdex) != self._total_N: + raise ValueError("length of dofdex needs to match total_N") if self._total_N > 0: space = 1 @@ -427,14 +428,14 @@ class CorrelatedFieldMaker: def _finalize_from_op(self): n_amplitudes = len(self._a) if self._total_N > 0: - hspace = makeDomain([UnstructuredDomain(self._total_N)] + - [dd.target[-1].harmonic_partner - for dd in self._a]) + hspace = makeDomain( + [UnstructuredDomain(self._total_N)] + + [dd.target[-1].harmonic_partner for dd in self._a]) spaces = tuple(range(1, n_amplitudes + 1)) amp_space = 1 else: hspace = makeDomain( - [dd.target[0].harmonic_partner for dd in self._a]) + [dd.target[0].harmonic_partner for dd in self._a]) spaces = tuple(range(n_amplitudes)) amp_space = 0 @@ -451,14 +452,12 @@ class CorrelatedFieldMaker: pd = PowerDistributor(hspace, self._a[0].target[amp_space], amp_space) for i in range(1, n_amplitudes): - pd = (pd @ PowerDistributor(pd.domain, - self._a[i].target[amp_space], - space=spaces[i])) + pd = (pd @ PowerDistributor( + pd.domain, self._a[i].target[amp_space], space=spaces[i])) a = ContractionOperator(pd.domain, spaces[1:]).adjoint @ self._a[0] for i in range(1, n_amplitudes): - co = ContractionOperator(pd.domain, - spaces[:i] + spaces[i+1:]) + co = ContractionOperator(pd.domain, spaces[:i] + spaces[i + 1:]) a = a*(co.adjoint @ self._a[i]) return ht(azm*(pd @ a)*ducktape(hspace, None, self._prefix + 'xi')) @@ -467,23 +466,26 @@ class CorrelatedFieldMaker: """ offset vs zeromode: volume factor """ - if offset is not None: - raise NotImplementedError - offset = float(offset) - op = self._finalize_from_op() - if prior_info > 0: - from ..sugar import from_random - samps = [ - from_random('normal', op.domain) for _ in range(prior_info) - ] - self.statistics_summary(samps) + if offset is not None: + # Deviations from this offset must not be considered here as they + # are learned by the zeromode + if isinstance(offset, (Field, MultiField)): + op = Adder(offset) @ op + else: + offset = float(offset) + op = Adder(full(op.target, offset)) @ op + self.statistics_summary(prior_info) return op - def statistics_summary(self, samples): + def statistics_summary(self, prior_info): + from ..sugar import from_random + + if prior_info == 0: + return + lst = [('Offset amplitude', self.amplitude_total_offset), ('Total fluctuation amplitude', self.total_fluctuation)] - namps = len(self._a) if namps > 1: for ii in range(namps): @@ -493,13 +495,19 @@ class CorrelatedFieldMaker: self.average_fluctuation(ii))) for kk, op in lst: - mean, stddev = _stats(op, samples) + sc = StatCalculator() + for _ in range(prior_info): + sc.add(op(from_random('normal', op.domain))) + mean = sc.mean.val + stddev = sc.var.sqrt().val for m, s in zip(mean.flatten(), stddev.flatten()): print('{}: {:.02E} ± {:.02E}'.format(kk, m, s)) def moment_slice_to_average(self, fluctuations_slice_mean, nsamples=1000): fluctuations_slice_mean = float(fluctuations_slice_mean) - assert fluctuations_slice_mean > 0 + if not fluctuations_slice_mean > 0: + msg = "fluctuations_slice_mean must be greater zero; got {!r}" + raise ValueError(msg.format(fluctuations_slice_mean)) from ..sugar import from_random scm = 1. for a in self._a: @@ -545,7 +553,8 @@ class CorrelatedFieldMaker: """Returns operator which acts on prior or posterior samples""" if len(self._a) == 0: raise NotImplementedError - assert space < len(self._a) + if space >= len(self._a): + raise ValueError(f"invalid space specified; got {space!r}") if len(self._a) == 1: return self.average_fluctuation(0) q = 1. @@ -561,7 +570,8 @@ class CorrelatedFieldMaker: """Returns operator which acts on prior or posterior samples""" if len(self._a) == 0: raise NotImplementedError - assert space < len(self._a) + if space >= len(self._a): + raise ValueError(f"invalid space specified; got {space!r}") if len(self._a) == 1: return self._a[0].fluctuation_amplitude return self._a[space].fluctuation_amplitude @@ -582,7 +592,8 @@ class CorrelatedFieldMaker: """Computes slice fluctuations from collection of field (defined in signal space) realizations.""" ldom = len(samples[0].domain) - assert space < ldom + if space >= ldom: + raise ValueError(f"invalid space specified; got {space!r}") if ldom == 1: return _total_fluctuation_realized(samples) res1, res2 = 0., 0. @@ -599,7 +610,8 @@ class CorrelatedFieldMaker: """Computes average fluctuations from collection of field (defined in signal space) realizations.""" ldom = len(samples[0].domain) - assert space < ldom + if space >= ldom: + raise ValueError(f"invalid space specified; got {space!r}") if ldom == 1: return _total_fluctuation_realized(samples) spaces = () diff --git a/nifty6/multi_field.py b/nifty6/multi_field.py index 744a9c4887e1014b891a38124eec45d2c1764944..b78caa5a0da725605ee9a4773ad229996a481c4f 100644 --- a/nifty6/multi_field.py +++ b/nifty6/multi_field.py @@ -132,6 +132,10 @@ class MultiField(object): return MultiField(domain, tuple(Field(dom, val) for dom in domain._domains)) + def to_global_data(self): + return {key: val.val + for key, val in zip(self._domain.keys(), self._val)} + @staticmethod def from_global_data(domain, arr): return MultiField(