Commit 47b92020 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

first attempt

parent 123f1b58
......@@ -52,7 +52,7 @@ class DomainTuple(object):
nax = len(thing.shape)
res[idx] = tuple(range(i, i+nax))
i += nax
return res
return tuple(res)
@staticmethod
def make(domain):
......
......@@ -30,17 +30,19 @@ from .linear_operator import LinearOperator
class ExpTransform(LinearOperator):
def __init__(self, target, dof):
if not ((isinstance(target, RGSpace) and target.harmonic) or
isinstance(target, PowerSpace)):
def __init__(self, target, dof, space=0):
self._target = DomainTuple.make(target)
self._space = int(space)
tgt = self._target[self._space]
if not ((isinstance(tgt, RGSpace) and tgt.harmonic) or
isinstance(tgt, PowerSpace)):
raise ValueError(
"Target must be a harmonic RGSpace or a power space.")
if np.isscalar(dof):
dof = np.full(len(target.shape), int(dof), dtype=np.int)
dof = np.full(len(tgt.shape), int(dof), dtype=np.int)
dof = np.array(dof)
ndim = len(target.shape)
ndim = len(tgt.shape)
t_mins = np.empty(ndim)
bindistances = np.empty(ndim)
......@@ -48,12 +50,12 @@ class ExpTransform(LinearOperator):
self._frac = [None] * ndim
for i in range(ndim):
if isinstance(target, RGSpace):
rng = np.arange(target.shape[i])
tmp = np.minimum(rng, target.shape[i]+1-rng)
k_array = tmp * target.distances[i]
if isinstance(tgt, RGSpace):
rng = np.arange(tgt.shape[i])
tmp = np.minimum(rng, tgt.shape[i]+1-rng)
k_array = tmp * tgt.distances[i]
else:
k_array = target.k_lengths
k_array = tgt.k_lengths
# avoid taking log of first entry
log_k_array = np.log(k_array[1:])
......@@ -77,8 +79,9 @@ class ExpTransform(LinearOperator):
from ..domains.log_rg_space import LogRGSpace
log_space = LogRGSpace(2*dof+1, bindistances,
t_mins, harmonic=False)
self._target = DomainTuple.make(target)
self._domain = DomainTuple.make(log_space)
self._domain = [dom for dom in self._target]
self._domain[self._space] = log_space
self._domain = DomainTuple.make(self._domain)
@property
def domain(self):
......@@ -94,9 +97,10 @@ class ExpTransform(LinearOperator):
ax = dobj.distaxis(x)
ndim = len(self.target.shape)
curshp = list(self._dom(mode).shape)
for d in range(ndim):
d0 = self._target.axes[self._space][0]
for d in self._target.axes[self._space]:
idx = (slice(None,),) * d
wgt = self._frac[d].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))
wgt = self._frac[d-d0].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))
if d == ax:
x = dobj.redistribute(x, nodist=(ax,))
......@@ -107,11 +111,11 @@ class ExpTransform(LinearOperator):
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
xnew = np.zeros(shp, dtype=x.dtype)
np.add.at(xnew, idx + (self._bindex[d],), x * (1.-wgt))
np.add.at(xnew, idx + (self._bindex[d]+1,), x * wgt)
np.add.at(xnew, idx + (self._bindex[d-d0],), x * (1.-wgt))
np.add.at(xnew, idx + (self._bindex[d-d0]+1,), x * wgt)
else: # TIMES
xnew = x[idx + (self._bindex[d],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d]+1,)] * wgt
xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
curshp[d] = self._tgt(mode).shape[d]
x = dobj.from_local_data(curshp, xnew, distaxis=curax)
......
......@@ -36,26 +36,34 @@ class QHTOperator(LinearOperator):
Parameters
----------
domain : LogRGSpace
The domain needs to be a LogRGSpace.
domain : domain, tuple of domains or DomainTuple
The full input domain
space : int
The index of the domain on which the operator acts.
domain[space] must be a harmonic LogRGSpace.
target : LogRGSpace
The target needs to be a LogRGSpace.
The target codomain of domain[space]
Must be a nonharmonic LogRGSpace.
"""
def __init__(self, domain, target):
if not domain.harmonic:
raise TypeError(
"HarmonicTransformOperator only works on a harmonic space")
if target.harmonic:
raise TypeError("Target is not a codomain of domain")
def __init__(self, domain, target, space=0):
self._domain = DomainTuple.make(domain)
self._space = int(space)
from ..domains.log_rg_space import LogRGSpace
if not isinstance(domain, LogRGSpace):
raise ValueError("Domain has to be a LogRGSpace!")
if not isinstance(self._domain[self._space], LogRGSpace):
raise ValueError("Domain[space] has to be a LogRGSpace!")
if not isinstance(target, LogRGSpace):
raise ValueError("Target has to be a LogRGSpace!")
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.make(target)
if not self._domain[self._space].harmonic:
raise TypeError(
"HarmonicTransformOperator only works on a harmonic space")
if target.harmonic:
raise TypeError("Target is not a codomain of domain")
self._target = [dom for dom in self._domain]
self._target[self._space] = target
self._target = DomainTuple.make(self._target)
@property
def domain(self):
......@@ -67,9 +75,10 @@ class QHTOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
x = x.val * self.domain[0].scalar_dvol()
n = len(self.domain[0].shape)
rng = range(n) if mode == self.TIMES else reversed(range(n))
dom = self._domain[self._space]
x = x.val * dom.scalar_dvol()
n = self._domain.axes[self._space]
rng = n if mode == self.TIMES else reversed(n)
ax = dobj.distaxis(x)
globshape = x.shape
for i in rng:
......
......@@ -27,11 +27,12 @@ from .endomorphic_operator import EndomorphicOperator
class SymmetrizingOperator(EndomorphicOperator):
def __init__(self, domain):
if not (isinstance(domain, LogRGSpace) and not domain.harmonic):
raise TypeError
def __init__(self, domain, space=0):
self._domain = DomainTuple.make(domain)
self._ndim = len(self.domain.shape)
self._space = int(space)
dom = self._domain[self._space]
if not (isinstance(dom, LogRGSpace) and not dom.harmonic):
raise TypeError
@property
def domain(self):
......@@ -42,7 +43,7 @@ class SymmetrizingOperator(EndomorphicOperator):
tmp = x.val.copy()
ax = dobj.distaxis(tmp)
globshape = tmp.shape
for i in range(self._ndim):
for i in self._domain.axes[self._space]:
lead = (slice(None),)*i
if i == ax:
tmp = dobj.redistribute(tmp, nodist=(ax,))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment