From ce79980ce86a146561f1441dde6e8f76c9c33c8f Mon Sep 17 00:00:00 2001
From: Philipp Arras <parras@mpa-garching.mpg.de>
Date: Thu, 7 Nov 2019 12:14:40 +0100
Subject: [PATCH] Add checks

---
 nifty5/library/correlated_fields.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/nifty5/library/correlated_fields.py b/nifty5/library/correlated_fields.py
index 7785ba511..89ed31f43 100644
--- a/nifty5/library/correlated_fields.py
+++ b/nifty5/library/correlated_fields.py
@@ -53,11 +53,16 @@ def _normal(mean, sig, key):
         sig*ducktape(DomainTuple.scalar_domain(), None, key))
 
 
+def _log_k_lengths(pspace):
+    return np.log(pspace.k_lengths[1:])
+
+
 class _SlopeRemover(EndomorphicOperator):
     def __init__(self, domain, logkl):
         self._domain = makeDomain(domain)
+        assert len(self._domain) == 1
+        assert isinstance(self._domain[0], PowerSpace)
         self._sc = logkl/float(logkl[-1])
-
         self._capability = self.TIMES | self.ADJOINT_TIMES
 
     def apply(self, x, mode):
@@ -72,13 +77,11 @@ class _SlopeRemover(EndomorphicOperator):
         return from_global_data(self._tgt(mode), res)
 
 
-def _log_k_lengths(pspace):
-    return np.log(pspace.k_lengths[1:])
-
-
 class _TwoLogIntegrations(LinearOperator):
     def __init__(self, target):
         self._target = makeDomain(target)
+        assert len(self._target) == 1
+        assert isinstance(self._target[0], PowerSpace)
         self._domain = makeDomain(
             UnstructuredDomain((2, self.target.shape[0] - 2)))
         self._capability = self.TIMES | self.ADJOINT_TIMES
@@ -112,6 +115,8 @@ class _TwoLogIntegrations(LinearOperator):
 class _Normalization(Operator):
     def __init__(self, domain):
         self._domain = self._target = makeDomain(domain)
+        assert len(self._domain) == 1
+        assert isinstance(domain[0], PowerSpace)
         hspace = self._domain[0].harmonic_partner
         pd = PowerDistributor(hspace, power_space=self._domain[0])
         cst = pd.adjoint(full(pd.target, 1.)).to_global_data_rw()
@@ -131,6 +136,7 @@ class _Normalization(Operator):
 class _SpecialSum(EndomorphicOperator):
     def __init__(self, domain):
         self._domain = makeDomain(domain)
+        assert len(self._domain) == 1
         self._capability = self.TIMES | self.ADJOINT_TIMES
 
     def apply(self, x, mode):
-- 
GitLab