Commit dc5042f0 authored by Reimar Heinrich Leike's avatar Reimar Heinrich Leike

added an operator doing multilinear interpolation

parent 309dd635
......@@ -24,6 +24,7 @@ from .operators.diagonal_operator import DiagonalOperator
from .operators.distributors import DOFDistributor, PowerDistributor
from .operators.domain_tuple_field_inserter import DomainTupleFieldInserter
from .operators.contraction_operator import ContractionOperator
from .operators.linear_interpolation import LinearInterpolator
from .operators.endomorphic_operator import EndomorphicOperator
from .operators.exp_transform import ExpTransform
from .operators.harmonic_operators import (
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2018 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import absolute_import, division, print_function
from ..compat import *
from .. import Field, UnstructuredDomain
from ..sugar import makeDomain
from .linear_operator import LinearOperator
from numpy import array, prod, mgrid, int64, arange, ravel_multi_index, zeros, abs, ravel
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import aslinearoperator
class LinearInterpolator(LinearOperator):
def __init__(self, domain, positions):
"""
:param domain:
RGSpace
:param target:
UnstructuredDomain, shape (ndata,)
:param positions:
positions at which to interpolate
Field with UnstructuredDomain, shape (dim, ndata)
"""
self._domain = makeDomain(domain)
N_points = positions.shape[1]
self._target = makeDomain(UnstructuredDomain(N_points))
self._capability = self.TIMES | self.ADJOINT_TIMES
self._build_mat(positions, N_points)
def _build_mat(self, positions, N_points):
ndim = positions.shape[0]
mg = mgrid[(slice(0,2),)*ndim]
mg = array(list(map(ravel, mg)))
dist = array(self.domain[0].distances).reshape((-1,1))
pos = positions/dist
excess = pos-pos.astype(int64)
pos = pos.astype(int64)
data = zeros((len(mg[0]), N_points))
ii = zeros((len(mg[0]), N_points), dtype=int64)
jj = zeros((len(mg[0]), N_points), dtype=int64)
for i in range(len(mg[0])):
factor = prod(abs(1-mg[:,i].reshape((-1,1))-excess),axis=0)
#print(factor)
data[i,:] = factor
fromi = pos+mg[:,i].reshape((-1,1))
ii[i, :] = arange(N_points)
jj[i, :] = ravel_multi_index(fromi, self.domain.shape)
self._mat = coo_matrix((data.reshape(-1),
(ii.reshape(-1),jj.reshape(-1))),
(N_points, prod(self.domain.shape)))
self._mat = aslinearoperator(self._mat)
def apply(self, x, mode):
self._check_input(x, mode)
x_val = x.to_global_data()
if mode == self.TIMES:
res = self._mat.matvec(x_val.reshape((-1,)))
return Field.from_global_data(self.target, res)
res = self._mat.rmatvec(x_val).reshape(self.domain.shape)
return Field.from_global_data(self.domain, res)
......@@ -58,12 +58,13 @@ class Consistency_Tests(unittest.TestCase):
op = a+b
ift.extra.consistency_check(op, dtype, dtype)
@expand(product(_h_spaces + _p_spaces + _pow_spaces,
[np.float64, np.complex128]))
def testVdotOperator(self, sp, dtype):
op = ift.VdotOperator(ift.Field.from_random("normal", sp,
dtype=dtype))
ift.extra.consistency_check(op, dtype, dtype)
def testLinearInterpolator(self):
sp = ift.RGSpace((10,8), distances=(0.1, 3.5))
pos = np.random.rand(2, 23)
pos[0,:] *= 0.9
pos[1,:] *= 7*3.5
op = ift.LinearInterpolator(sp, pos)
ift.extra.consistency_check(op)
@expand(product([(ift.RGSpace(10, harmonic=True), 4, 0),
(ift.RGSpace((24, 31), distances=(0.4, 2.34),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment