Commit 0e8af4d4 authored by Philipp Arras's avatar Philipp Arras
Browse files

Simplifications SlopeOperator

parent 30f8f70c
......@@ -46,40 +46,27 @@ class SlopeOperator(LinearOperator):
sigmas : np.array, shape=(2,)
The slope variance and the y-intercept variance.
"""
def __init__(self, target):
if not isinstance(target, LogRGSpace):
raise TypeError
if len(target.shape) != 1:
raise ValueError("Slope Operator only works for ndim == 1")
self._domain = DomainTuple.make(UnstructuredDomain((2,)))
self._target = DomainTuple.make(target)
self._capability = self.TIMES | self.ADJOINT_TIMES
self.ndim = len(self.target[0].shape)
if self.ndim != 1:
raise ValueError("Slope Operator only works for ndim == 1")
# Prepare pos
self.pos = self.target[0].get_k_array()-self.target[0].t_0[0]
pos = self.target[0].get_k_array() - self.target[0].t_0[0]
self._pos = pos[0, 1:]
def apply(self, x, mode):
self._check_input(x, mode)
# Times
if mode == self.TIMES:
inp = x.to_global_data()
res = inp[-1]
for i in range(self.ndim):
res = res + inp[i] * self.pos[i]
res[0] = 0.
return Field.from_global_data(self.target, res)
# Adjoint times
res = np.zeros(self.domain[0].shape, dtype=x.dtype)
xglob = x.to_global_data()
res[-1] = np.sum(xglob[1:])
for i in range(self.ndim):
res[i] = np.sum(self.pos[i][1:] * xglob[1:])
return Field.from_global_data(self.domain, res)
if mode == self.TIMES:
res = np.empty(self.target.shape, dtype=x.dtype)
res[0] = 0
res[1:] = inp[1] + inp[0]*self._pos
else:
res = np.array(
[np.sum(self._pos*inp[1:]),
np.sum(inp[1:])], dtype=x.dtype)
return Field.from_global_data(self._tgt(mode), res)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment