Commit 7db6d4f9 authored by Martin Reinecke's avatar Martin Reinecke

compatification

parent 12c44bb4
......@@ -25,9 +25,8 @@ class FieldZeroPadder(LinearOperator):
raise ValueError("Shape mismatch")
if any([a < b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must be larger than old shape")
tgt = RGSpace(new_shape, dom.distances)
self._target = list(self._domain)
self._target[self._space] = tgt
self._target[self._space] = RGSpace(new_shape, dom.distances)
self._target = DomainTuple.make(self._target)
self._capability = self.TIMES | self.ADJOINT_TIMES
......
......@@ -68,60 +68,38 @@ class LaplaceOperator(EndomorphicOperator):
self._dposc[1:] += self._dpos
self._dposc *= 0.5
def _times(self, x):
def apply(self, x, mode):
self._check_input(x, mode)
axes = x.domain.axes[self._space]
axis = axes[0]
locval = x.val
if axis == dobj.distaxis(locval):
locval = dobj.redistribute(locval, nodist=(axis,))
val = dobj.local_data(locval)
nval = len(self._dposc)
prefix = (slice(None),) * axis
sl_l = prefix + (slice(None, -1),) # "left" slice
sl_r = prefix + (slice(1, None),) # "right" slice
dpos = self._dpos.reshape((1,)*axis + (nval-1,))
dposc = self._dposc.reshape((1,)*axis + (nval,))
deriv = (val[sl_r]-val[sl_l])/dpos # defined between points
locval = x.val
if axis == dobj.distaxis(locval):
locval = dobj.redistribute(locval, nodist=(axis,))
val = dobj.local_data(locval)
ret = np.empty_like(val)
ret[sl_l] = deriv
ret[prefix + (-1,)] = 0.
ret[sl_r] -= deriv
ret /= dposc
ret[prefix + (slice(None, 2),)] = 0.
ret[prefix + (-1,)] = 0.
if mode == self.TIMES:
deriv = (val[sl_r]-val[sl_l])/dpos # defined between points
ret[sl_l] = deriv
ret[prefix + (-1,)] = 0.
ret[sl_r] -= deriv
ret /= dposc
ret[prefix + (slice(None, 2),)] = 0.
ret[prefix + (-1,)] = 0.
else:
val = val/dposc
val[prefix + (slice(None, 2),)] = 0.
val[prefix + (-1,)] = 0.
deriv = (val[sl_r]-val[sl_l])/dpos # defined between points
ret[sl_l] = deriv
ret[prefix + (-1,)] = 0.
ret[sl_r] -= deriv
ret = dobj.from_local_data(locval.shape, ret, dobj.distaxis(locval))
if dobj.distaxis(locval) != dobj.distaxis(x.val):
ret = dobj.redistribute(ret, dist=dobj.distaxis(x.val))
return Field(self.domain, val=ret)
def _adjoint_times(self, x):
axes = x.domain.axes[self._space]
axis = axes[0]
nval = len(self._dposc)
prefix = (slice(None),) * axis
sl_l = prefix + (slice(None, -1),) # "left" slice
sl_r = prefix + (slice(1, None),) # "right" slice
dpos = self._dpos.reshape((1,)*axis + (nval-1,))
dposc = self._dposc.reshape((1,)*axis + (nval,))
yf = x.val
if axis == dobj.distaxis(yf):
yf = dobj.redistribute(yf, nodist=(axis,))
y = dobj.local_data(yf)
y = y/dposc
y[prefix + (slice(None, 2),)] = 0.
y[prefix + (-1,)] = 0.
deriv = (y[sl_r]-y[sl_l])/dpos # defined between points
ret = np.empty_like(y)
ret[sl_l] = deriv
ret[prefix + (-1,)] = 0.
ret[sl_r] -= deriv
ret = dobj.from_local_data(x.shape, ret, dobj.distaxis(yf))
if dobj.distaxis(yf) != dobj.distaxis(x.val):
ret = dobj.redistribute(ret, dist=dobj.distaxis(x.val))
return Field(self.domain, val=ret)
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return self._times(x)
return self._adjoint_times(x)
......@@ -119,9 +119,7 @@ class GeometryRemover(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return x.cast_domain(self._target)
return x.cast_domain(self._domain)
return x.cast_domain(self._tgt(mode))
class NullOperator(LinearOperator):
......@@ -150,7 +148,4 @@ class NullOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if mode == self.TIMES:
return self._nullfield(self._target)
return self._nullfield(self._domain)
return self._nullfield(self._tgt(mode))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment