diff --git a/nifty5/domains/log_rg_space.py b/nifty5/domains/log_rg_space.py index 3aabf68de347ffcbc574a5819ace5fd9a86ce068..740ac391b1500cca6c510ad42f6047a324fa1143 100644 --- a/nifty5/domains/log_rg_space.py +++ b/nifty5/domains/log_rg_space.py @@ -74,11 +74,9 @@ class LogRGSpace(StructuredDomain): % (self.shape, self.harmonic)) def get_default_codomain(self): - if self._harmonic: - raise ValueError("only supported for nonharmonic space") codomain_bindistances = 1. / (self.bindistances * self.shape) return LogRGSpace(self.shape, codomain_bindistances, - np.zeros(len(self.shape)), True) + self._t_0, True) def get_k_length_array(self): ib = dobj.ibegin_from_shape(self._shape) diff --git a/nifty5/library/amplitude_model.py b/nifty5/library/amplitude_model.py index 5f599503a22187c4cbed09efd8f3a07aced06427..1aa00ed24003d18561c13bca2e4949022b355b85 100644 --- a/nifty5/library/amplitude_model.py +++ b/nifty5/library/amplitude_model.py @@ -65,9 +65,9 @@ def make_amplitude_model(s_space, Npixdof, ceps_a, ceps_k, sm, sv, im, iv, p_space = PowerSpace(h_space) exp_transform = ExpTransform(p_space, Npixdof) logk_space = exp_transform.domain[0] - dof_space = logk_space.get_default_codomain() + qht = QHTOperator(target=logk_space) + dof_space = qht.domain[0] param_space = UnstructuredDomain(2) - qht = QHTOperator(dof_space, logk_space) sym = SymmetrizingOperator(logk_space) phi_mean = np.array([sm, im]) diff --git a/nifty5/operators/qht_operator.py b/nifty5/operators/qht_operator.py index 3eebc338146c167b3ebc40b837291b00accccf0d..b5a30bbb8762b5a1071787eff61ae93510b33098 100644 --- a/nifty5/operators/qht_operator.py +++ b/nifty5/operators/qht_operator.py @@ -36,34 +36,28 @@ class QHTOperator(LinearOperator): Parameters ---------- - domain : domain, tuple of domains or DomainTuple - The full input domain + target : domain, tuple of domains or DomainTuple + The full output domain space : int The index of the domain on which the operator acts. - domain[space] must be a harmonic LogRGSpace. - target : LogRGSpace - The target codomain of domain[space] - Must be a nonharmonic LogRGSpace. + target[space] must be a nonharmonic LogRGSpace. """ - def __init__(self, domain, target, space=0): - self._domain = DomainTuple.make(domain) - self._space = infer_space(self._domain, space) + def __init__(self, target, space=0): + self._target = DomainTuple.make(target) + self._space = infer_space(self._target, space) from ..domains.log_rg_space import LogRGSpace - if not isinstance(self._domain[self._space], LogRGSpace): - raise ValueError("Domain[space] has to be a LogRGSpace!") - if not isinstance(target, LogRGSpace): - raise ValueError("Target has to be a LogRGSpace!") + if not isinstance(self._target[self._space], LogRGSpace): + raise ValueError("target[space] has to be a LogRGSpace!") - if not self._domain[self._space].harmonic: + if self._target[self._space].harmonic: raise TypeError( - "QHTOperator only works on a harmonic space") - if target.harmonic: - raise TypeError("Target is not a codomain of domain") + "target[space] must be a nonharmonic space") - self._target = [dom for dom in self._domain] - self._target[self._space] = target - self._target = DomainTuple.make(self._target) + self._domain = [dom for dom in self._target] + self._domain[self._space] = \ + self._target[self._space].get_default_codomain() + self._domain = DomainTuple.make(self._domain) @property def domain(self): diff --git a/test/test_operators/test_adjoint.py b/test/test_operators/test_adjoint.py index f6f9e9542e5cd41d35eb7d9444164e663b2513cd..74cf970c27167a6b13095e603ff31784c9fc6903 100644 --- a/test/test_operators/test_adjoint.py +++ b/test/test_operators/test_adjoint.py @@ -143,8 +143,6 @@ class Consistency_Tests(unittest.TestCase): ift.LogRGSpace(17, [3.], [.7])), 1)], [np.float64])) def testQHTOperator(self, args, dtype): - dom = ift.DomainTuple.make(args[0]) - tgt = list(dom) - tgt[args[1]] = tgt[args[1]].get_default_codomain() - op = ift.QHTOperator(tgt, dom[args[1]], args[1]) + tgt = ift.DomainTuple.make(args[0]) + op = ift.QHTOperator(tgt, args[1]) ift.extra.consistency_check(op, dtype, dtype)