Commit 0e487953 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweak operator

parent 9b53ee63
......@@ -111,7 +111,7 @@ class ExpTransform(LinearOperator):
curshp = list(self._dom(mode).shape)
d0 = self._target.axes[self._space][0]
for d in self._target.axes[self._space]:
idx = (slice(None,),) * d
idx = (slice(None),) * d
wgt = self._frac[d-d0].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))
if d == ax:
......
......@@ -19,84 +19,76 @@
from __future__ import absolute_import, division, print_function
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import aslinearoperator
from .. import dobj
from ..compat import *
from ..domain_tuple import DomainTuple
from ..domains.rg_space import RGSpace
from ..field import Field
from ..utilities import infer_space, special_add_at
from .linear_operator import LinearOperator
class RegriddingOperator(LinearOperator):
def __init__(self, domain, target):
super(RegriddingOperator, self).__init__()
def __init__(self, domain, new_shape, space=0):
self._domain = DomainTuple.make(domain)
self._target = DomainTuple.make(target)
if any(np.array(self.domain.shape) < np.array(self.target.shape)):
print('Warning: The regridding operator is not intended to be used for upsampling.')
self._space = infer_space(self._domain, space)
dom = self._domain[self._space]
# domain: fine domain
# target: coarse domain
distances_dom = sum([list(dom.distances) for dom in self.domain], [])
distances_tgt = sum([list(dom.distances) for dom in self.target], [])
dim = len(distances_tgt)
if not isinstance(dom, RGSpace):
raise TypeError("RGSpace required")
if len(new_shape) != len(dom.shape):
print(new_shape, dom.shape)
raise ValueError("Shape mismatch")
if any([a > b for a, b in zip(new_shape, dom.shape)]):
raise ValueError("New shape must not be larger than old shape")
# index arrays
dom_indices = np.arange(self.domain.size).reshape(self.domain.shape)
tgt_indices = np.arange(self.target.size).reshape(self.target.shape)
newdist = tuple(dom.distances[i]*dom.shape[i]/new_shape[i]
for i in range(len(dom.shape)))
# Input for sparse matrix
foo = (self.domain.size, 2**len(self.domain.shape))
rs, cs, ws = np.zeros(foo), np.zeros(foo), np.zeros(foo)
tgt = RGSpace(new_shape, newdist)
self._target = list(self._domain)
self._target[self._space] = tgt
self._target = DomainTuple.make(self._target)
self._capability = self.TIMES | self.ADJOINT_TIMES
print('Initializing...')
# Calculate weights
fac = np.array(distances_dom)/distances_tgt
find_neighbours = np.array(
np.meshgrid(*[[0, 1] for _ in range(dim)])).T.reshape(-1, dim)
for ind, global_index in np.ndenumerate(dom_indices):
p_in_tgt = np.outer(ind*fac, np.ones(2**dim)).T
neighbours = p_in_tgt.astype(int)+find_neighbours
ws[global_index] = np.prod(1-np.abs(neighbours-p_in_tgt), axis=1)
cs[global_index] = dom_indices[tuple(
np.array(ind) % self.domain.shape)]
rs[global_index] = [
tgt_indices[tuple(n % self.target.shape)] for n in neighbours
]
if global_index % 10000 == 9999:
print('{}%'.format(np.round(global_index/dom_indices.size*100), 3))
print('Done')
# Throw away zero weights and flatten at the same time
mask = ws != 0
rs, cs, ws = rs[mask], cs[mask], ws[mask]
if np.sum(ws) != self.domain.size:
raise RuntimeError
smat = csr_matrix(
(ws, (rs, cs)), shape=(self.target.size, self.domain.size))
self._smat = aslinearoperator(smat)
ndim = len(new_shape)
bindistances = np.empty(ndim)
self._bindex = [None] * ndim
self._frac = [None] * ndim
for d in range(ndim):
tmp = np.arange(new_shape[d])*(newdist[d]/dom.distances[d])
self._bindex[d] = np.minimum(dom.shape[d]-2,tmp.astype(np.int))
self._frac = tmp-self._bindex[d]
def apply(self, x, mode):
self._check_input(x, mode)
inp = x.to_global_data()
if mode == self.TIMES:
res = self._smat.matvec(inp.reshape(-1))
else:
res = self._smat.rmatvec(inp.reshape(-1))
res *= self.target.size/self.domain.size
tgt = self._tgt(mode)
return Field.from_global_data(tgt, res.reshape(tgt.shape))
x = x.val
ax = dobj.distaxis(x)
ndim = len(self.target.shape)
curshp = list(self._dom(mode).shape)
d0 = self._target.axes[self._space][0]
for d in self._target.axes[self._space]:
idx = (slice(None),) * d
wgt = self._frac[d-d0].reshape((1,)*d + (-1,) + (1,)*(ndim-d-1))
@property
def domain(self):
return self._domain
if d == ax:
x = dobj.redistribute(x, nodist=(ax,))
curax = dobj.distaxis(x)
x = dobj.local_data(x)
@property
def target(self):
return self._target
if mode == self.ADJOINT_TIMES:
shp = list(x.shape)
shp[d] = self._tgt(mode).shape[d]
xnew = np.zeros(shp, dtype=x.dtype)
xnew = special_add_at(xnew, d, self._bindex[d-d0], x*(1.-wgt))
xnew = special_add_at(xnew, d, self._bindex[d-d0]+1, x*wgt)
else: # TIMES
xnew = x[idx + (self._bindex[d-d0],)] * (1.-wgt)
xnew += x[idx + (self._bindex[d-d0]+1,)] * wgt
@property
def capability(self):
return self.TIMES | self.ADJOINT_TIMES
curshp[d] = self._tgt(mode).shape[d]
x = dobj.from_local_data(curshp, xnew, distaxis=curax)
if d == ax:
x = dobj.redistribute(x, dist=ax)
return Field(self._tgt(mode), val=x)
......@@ -241,9 +241,11 @@ class Consistency_Tests(unittest.TestCase):
op = ift.QHTOperator(tgt, args[1])
ift.extra.consistency_check(op, dtype, dtype)
@expand([[ift.RGSpace(13, 52, 40), ift.RGSpace(4, 6, 25)],
[ift.RGSpace(128, 128), ift.RGSpace(45, 48)],
[ift.RGSpace(13), ift.RGSpace(7)]])
def testRegridding(self, domain, target):
op = ift.RegriddingOperator(domain, target)
@expand([[ift.RGSpace((13, 52, 40)), (4, 6, 25), None],
[ift.RGSpace((128, 128)), (45, 48), 0],
[ift.RGSpace(13), (7,), None],
[(ift.HPSpace(3), ift.RGSpace((12, 24),distances=0.3)),
(12, 12), 1]])
def testRegridding(self, domain, shape, space):
op = ift.RegriddingOperator(domain, shape, space)
ift.extra.consistency_check(op)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment