From ce79980ce86a146561f1441dde6e8f76c9c33c8f Mon Sep 17 00:00:00 2001 From: Philipp Arras <parras@mpa-garching.mpg.de> Date: Thu, 7 Nov 2019 12:14:40 +0100 Subject: [PATCH] Add checks --- nifty5/library/correlated_fields.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py index 7785ba511..89ed31f43 100644 --- a/nifty5/library/correlated_fields.py +++ b/nifty5/library/correlated_fields.py @@ -53,11 +53,16 @@ def _normal(mean, sig, key): sig*ducktape(DomainTuple.scalar_domain(), None, key)) +def _log_k_lengths(pspace): + return np.log(pspace.k_lengths[1:]) + + class _SlopeRemover(EndomorphicOperator): def __init__(self, domain, logkl): self._domain = makeDomain(domain) + assert len(self._domain) == 1 + assert isinstance(self._domain[0], PowerSpace) self._sc = logkl/float(logkl[-1]) - self._capability = self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): @@ -72,13 +77,11 @@ class _SlopeRemover(EndomorphicOperator): return from_global_data(self._tgt(mode), res) -def _log_k_lengths(pspace): - return np.log(pspace.k_lengths[1:]) - - class _TwoLogIntegrations(LinearOperator): def __init__(self, target): self._target = makeDomain(target) + assert len(self._target) == 1 + assert isinstance(self._target[0], PowerSpace) self._domain = makeDomain( UnstructuredDomain((2, self.target.shape[0] - 2))) self._capability = self.TIMES | self.ADJOINT_TIMES @@ -112,6 +115,8 @@ class _TwoLogIntegrations(LinearOperator): class _Normalization(Operator): def __init__(self, domain): self._domain = self._target = makeDomain(domain) + assert len(self._domain) == 1 + assert isinstance(domain[0], PowerSpace) hspace = self._domain[0].harmonic_partner pd = PowerDistributor(hspace, power_space=self._domain[0]) cst = pd.adjoint(full(pd.target, 1.)).to_global_data_rw() @@ -131,6 +136,7 @@ class _Normalization(Operator): class _SpecialSum(EndomorphicOperator): def __init__(self, domain): self._domain = makeDomain(domain) + assert len(self._domain) == 1 self._capability = self.TIMES | self.ADJOINT_TIMES def apply(self, x, mode): -- GitLab