Commit eaf710b7 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'flexible_interpolator' into 'NIFTy_5'

Flexible interpolator

See merge request ift/nifty-dev!139
parents f97a9502 5e047bf0
......@@ -18,28 +18,30 @@
from __future__ import absolute_import, division, print_function
from ..compat import *
import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import aslinearoperator
from .. import Field, UnstructuredDomain
from ..compat import *
from ..domains.rg_space import RGSpace
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..sugar import makeDomain
from .linear_operator import LinearOperator
from numpy import (array, prod, mgrid, int64, arange, ravel_multi_index, zeros,
abs, ravel)
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import aslinearoperator
class LinearInterpolator(LinearOperator):
def __init__(self, domain, positions):
"""
Multilinear interpolation for points in an RGSpace
:param domain:
RGSpace
:param target:
UnstructuredDomain, shape (ndata,)
:param positions:
positions at which to interpolate
Field with UnstructuredDomain, shape (dim, ndata)
positions that are not within the RGSpace are wrapped
according to periodic boundary conditions
"""
self._domain = makeDomain(domain)
N_points = positions.shape[1]
......@@ -49,31 +51,38 @@ class LinearInterpolator(LinearOperator):
def _build_mat(self, positions, N_points):
ndim = positions.shape[0]
mg = mgrid[(slice(0, 2),)*ndim]
mg = array(list(map(ravel, mg)))
dist = array(self.domain[0].distances).reshape((-1, 1))
mg = np.mgrid[(slice(0, 2),)*ndim]
mg = np.array(list(map(np.ravel, mg)))
dist = []
for dom in self.domain:
if not isinstance(dom, RGSpace):
raise TypeError
dist.append(list(dom.distances))
dist = np.array(dist).reshape(-1, 1)
pos = positions/dist
excess = pos-pos.astype(int64)
pos = pos.astype(int64)
data = zeros((len(mg[0]), N_points))
ii = zeros((len(mg[0]), N_points), dtype=int64)
jj = zeros((len(mg[0]), N_points), dtype=int64)
excess = pos-pos.astype(np.int64)
pos = pos.astype(np.int64)
max_index = np.array(self.domain.shape).reshape(-1, 1)
data = np.zeros((len(mg[0]), N_points))
ii = np.zeros((len(mg[0]), N_points), dtype=np.int64)
jj = np.zeros((len(mg[0]), N_points), dtype=np.int64)
for i in range(len(mg[0])):
factor = prod(abs(1-mg[:, i].reshape((-1, 1))-excess), axis=0)
factor = np.prod(np.abs(1-mg[:, i].reshape(-1, 1)-excess),
axis=0)
data[i, :] = factor
fromi = pos+mg[:, i].reshape((-1, 1))
ii[i, :] = arange(N_points)
jj[i, :] = ravel_multi_index(fromi, self.domain.shape)
fromi = (pos+mg[:, i].reshape(-1, 1)) % max_index
ii[i, :] = np.arange(N_points)
jj[i, :] = np.ravel_multi_index(fromi, self.domain.shape)
self._mat = coo_matrix((data.reshape(-1),
(ii.reshape(-1), jj.reshape(-1))),
(N_points, prod(self.domain.shape)))
(N_points, np.prod(self.domain.shape)))
self._mat = aslinearoperator(self._mat)
def apply(self, x, mode):
self._check_input(x, mode)
x_val = x.to_global_data()
if mode == self.TIMES:
res = self._mat.matvec(x_val.reshape((-1,)))
res = self._mat.matvec(x_val.reshape(-1))
return Field.from_global_data(self.target, res)
res = self._mat.rmatvec(x_val).reshape(self.domain.shape)
return Field.from_global_data(self.domain, res)
......
......@@ -131,6 +131,9 @@ class GeometryRemover(LinearOperator):
----------
domain: Domain, tuple of Domain, or DomainTuple:
the full input domain of the operator.
space: int, optional
The index of the subdomain on which the operator should act
If None, it acts on all spaces
Notes
-----
......@@ -139,10 +142,14 @@ class GeometryRemover(LinearOperator):
is carried out.
"""
def __init__(self, domain):
def __init__(self, domain, space=None):
self._domain = DomainTuple.make(domain)
target_list = [UnstructuredDomain(dom.shape) for dom in self._domain]
self._target = DomainTuple.make(target_list)
if space is not None:
tgt = [dom for dom in self._domain]
tgt[space] = UnstructuredDomain(self._domain[space].shape)
else:
tgt = [UnstructuredDomain(dom.shape) for dom in self._domain]
self._target = DomainTuple.make(tgt)
self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment