Commit eaf710b7 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

Merge branch 'flexible_interpolator' into 'NIFTy_5'

Flexible interpolator

See merge request ift/nifty-dev!139
parents f97a9502 5e047bf0
...@@ -18,28 +18,30 @@ ...@@ -18,28 +18,30 @@
from __future__ import absolute_import, division, print_function from __future__ import absolute_import, division, print_function
from ..compat import * import numpy as np
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import aslinearoperator
from .. import Field, UnstructuredDomain from ..compat import *
from ..domains.rg_space import RGSpace
from ..domains.unstructured_domain import UnstructuredDomain
from ..field import Field
from ..sugar import makeDomain from ..sugar import makeDomain
from .linear_operator import LinearOperator from .linear_operator import LinearOperator
from numpy import (array, prod, mgrid, int64, arange, ravel_multi_index, zeros,
abs, ravel)
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import aslinearoperator
class LinearInterpolator(LinearOperator): class LinearInterpolator(LinearOperator):
def __init__(self, domain, positions): def __init__(self, domain, positions):
""" """
Multilinear interpolation for points in an RGSpace
:param domain: :param domain:
RGSpace RGSpace
:param target:
UnstructuredDomain, shape (ndata,)
:param positions: :param positions:
positions at which to interpolate positions at which to interpolate
Field with UnstructuredDomain, shape (dim, ndata) Field with UnstructuredDomain, shape (dim, ndata)
positions that are not within the RGSpace are wrapped
according to periodic boundary conditions
""" """
self._domain = makeDomain(domain) self._domain = makeDomain(domain)
N_points = positions.shape[1] N_points = positions.shape[1]
...@@ -49,31 +51,38 @@ class LinearInterpolator(LinearOperator): ...@@ -49,31 +51,38 @@ class LinearInterpolator(LinearOperator):
def _build_mat(self, positions, N_points): def _build_mat(self, positions, N_points):
ndim = positions.shape[0] ndim = positions.shape[0]
mg = mgrid[(slice(0, 2),)*ndim] mg = np.mgrid[(slice(0, 2),)*ndim]
mg = array(list(map(ravel, mg))) mg = np.array(list(map(np.ravel, mg)))
dist = array(self.domain[0].distances).reshape((-1, 1)) dist = []
for dom in self.domain:
if not isinstance(dom, RGSpace):
raise TypeError
dist.append(list(dom.distances))
dist = np.array(dist).reshape(-1, 1)
pos = positions/dist pos = positions/dist
excess = pos-pos.astype(int64) excess = pos-pos.astype(np.int64)
pos = pos.astype(int64) pos = pos.astype(np.int64)
data = zeros((len(mg[0]), N_points)) max_index = np.array(self.domain.shape).reshape(-1, 1)
ii = zeros((len(mg[0]), N_points), dtype=int64) data = np.zeros((len(mg[0]), N_points))
jj = zeros((len(mg[0]), N_points), dtype=int64) ii = np.zeros((len(mg[0]), N_points), dtype=np.int64)
jj = np.zeros((len(mg[0]), N_points), dtype=np.int64)
for i in range(len(mg[0])): for i in range(len(mg[0])):
factor = prod(abs(1-mg[:, i].reshape((-1, 1))-excess), axis=0) factor = np.prod(np.abs(1-mg[:, i].reshape(-1, 1)-excess),
axis=0)
data[i, :] = factor data[i, :] = factor
fromi = pos+mg[:, i].reshape((-1, 1)) fromi = (pos+mg[:, i].reshape(-1, 1)) % max_index
ii[i, :] = arange(N_points) ii[i, :] = np.arange(N_points)
jj[i, :] = ravel_multi_index(fromi, self.domain.shape) jj[i, :] = np.ravel_multi_index(fromi, self.domain.shape)
self._mat = coo_matrix((data.reshape(-1), self._mat = coo_matrix((data.reshape(-1),
(ii.reshape(-1), jj.reshape(-1))), (ii.reshape(-1), jj.reshape(-1))),
(N_points, prod(self.domain.shape))) (N_points, np.prod(self.domain.shape)))
self._mat = aslinearoperator(self._mat) self._mat = aslinearoperator(self._mat)
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
x_val = x.to_global_data() x_val = x.to_global_data()
if mode == self.TIMES: if mode == self.TIMES:
res = self._mat.matvec(x_val.reshape((-1,))) res = self._mat.matvec(x_val.reshape(-1))
return Field.from_global_data(self.target, res) return Field.from_global_data(self.target, res)
res = self._mat.rmatvec(x_val).reshape(self.domain.shape) res = self._mat.rmatvec(x_val).reshape(self.domain.shape)
return Field.from_global_data(self.domain, res) return Field.from_global_data(self.domain, res)
......
...@@ -131,6 +131,9 @@ class GeometryRemover(LinearOperator): ...@@ -131,6 +131,9 @@ class GeometryRemover(LinearOperator):
---------- ----------
domain: Domain, tuple of Domain, or DomainTuple: domain: Domain, tuple of Domain, or DomainTuple:
the full input domain of the operator. the full input domain of the operator.
space: int, optional
The index of the subdomain on which the operator should act
If None, it acts on all spaces
Notes Notes
----- -----
...@@ -139,10 +142,14 @@ class GeometryRemover(LinearOperator): ...@@ -139,10 +142,14 @@ class GeometryRemover(LinearOperator):
is carried out. is carried out.
""" """
def __init__(self, domain): def __init__(self, domain, space=None):
self._domain = DomainTuple.make(domain) self._domain = DomainTuple.make(domain)
target_list = [UnstructuredDomain(dom.shape) for dom in self._domain] if space is not None:
self._target = DomainTuple.make(target_list) tgt = [dom for dom in self._domain]
tgt[space] = UnstructuredDomain(self._domain[space].shape)
else:
tgt = [UnstructuredDomain(dom.shape) for dom in self._domain]
self._target = DomainTuple.make(tgt)
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
def apply(self, x, mode): def apply(self, x, mode):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment