Commit f9c22026 authored by Martin Reinecke's avatar Martin Reinecke

cleanup SlopeOperator interface

parent 412b149a
......@@ -106,23 +106,21 @@ class AmplitudeModel(Operator):
from ..operators.symmetrizing_operator import SymmetrizingOperator
h_space = s_space.get_default_codomain()
p_space = PowerSpace(h_space)
self._exp_transform = ExpTransform(p_space, Npixdof)
self._exp_transform = ExpTransform(PowerSpace(h_space), Npixdof)
logk_space = self._exp_transform.domain[0]
qht = QHTOperator(target=logk_space)
dof_space = qht.domain[0]
param_space = UnstructuredDomain(2)
sym = SymmetrizingOperator(logk_space)
phi_mean = np.array([sm, im])
phi_sig = np.array([sv, iv])
self._slope = SlopeOperator(param_space, logk_space, phi_sig)
self._norm_phi_mean = Field.from_global_data(param_space,
self._slope = SlopeOperator(logk_space, phi_sig)
self._norm_phi_mean = Field.from_global_data(self._slope.domain,
self._domain = MultiDomain.make({keys[0]: dof_space,
keys[1]: param_space})
keys[1]: self._slope.domain})
self._target =
kern = lambda k: _ceps_kernel(dof_space, k, ceps_a, ceps_k)
......@@ -52,8 +52,7 @@ class QHTOperator(LinearOperator):
raise ValueError("target[space] has to be a LogRGSpace!")
if self._target[self._space].harmonic:
raise TypeError(
"target[space] must be a nonharmonic space")
raise TypeError("target[space] must be a nonharmonic space")
self._domain = [dom for dom in self._target]
self._domain[self._space] = \
......@@ -34,31 +34,25 @@ class SlopeOperator(LinearOperator):
This operator creates a field on a LogRGSpace, which is created
according to a slope of given entries, (mean, y-intercept).
The slope mean is the powerlaw of the field in normal-space.
The slope mean is the power law of the field in normal-space.
domain : domain or DomainTuple, shape=(2,)
It has to be and UnstructuredDomain.
It has to be an UnstructuredDomain.
The domain of the slope mean and the y-intercept mean.
target : domain or DomainTuple
The output domain has to a LogRGSpace
sigmas : np.array, shape=(2,)
The slope variance and the y-intercept variance.
def __init__(self, domain, target, sigmas):
def __init__(self, target, sigmas):
if not isinstance(target, LogRGSpace):
raise TypeError
if not (isinstance(domain, UnstructuredDomain) and domain.shape == (2,)):
raise TypeError
self._domain = DomainTuple.make(domain)
self._domain = DomainTuple.make(UnstructuredDomain((2,)))
self._target = DomainTuple.make(target)
self._capability = self.TIMES | self.ADJOINT_TIMES
if self.domain[0].shape != (len([0].shape) + 1,):
raise AssertionError("Shape mismatch!")
self._sigmas = sigmas
self.ndim = len([0].shape)
self.pos = np.zeros((self.ndim,) +[0].shape)
......@@ -74,8 +74,7 @@ class Consistency_Tests(unittest.TestCase):
tmp = ift.ExpTransform(ift.PowerSpace(args[0]), args[1], args[2])
tgt = tmp.domain[0]
sig = np.array([0.3, 0.13])
dom = ift.UnstructuredDomain(2)
op = ift.SlopeOperator(dom, tgt, sig)
op = ift.SlopeOperator(tgt, sig)
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment