Commit 92055ad4 authored by Philipp Arras's avatar Philipp Arras
Browse files

Add docu and type checks

parent 750bafe7
...@@ -15,6 +15,9 @@ ...@@ -15,6 +15,9 @@
# #
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik. # NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik.
from functools import reduce
from operator import mul
import numpy as np import numpy as np
from scipy.sparse import coo_matrix from scipy.sparse import coo_matrix
from scipy.sparse.linalg import aslinearoperator from scipy.sparse.linalg import aslinearoperator
...@@ -27,50 +30,69 @@ from .linear_operator import LinearOperator ...@@ -27,50 +30,69 @@ from .linear_operator import LinearOperator
class LinearInterpolator(LinearOperator): class LinearInterpolator(LinearOperator):
def __init__(self, domain, positions): def __init__(self, domain, sampling_points):
""" """
Multilinear interpolation for points in an RGSpace Multilinear interpolation for points in an RGSpace
:param domain: Parameters
RGSpace ----------
:param positions: domain : RGSpace
positions at which to interpolate positions : numpy.ndarray
Positions at which to interpolate
Field with UnstructuredDomain, shape (dim, ndata) Field with UnstructuredDomain, shape (dim, ndata)
positions that are not within the RGSpace are wrapped
according to periodic boundary conditions Notes
-----
Positions that are not within the RGSpace are wrapped according to
periodic boundary conditions. This reflects the general property of
RGSpaces to be tori topologically.
""" """
self._domain = makeDomain(domain) self._domain = makeDomain(domain)
N_points = positions.shape[1] for dom in self.domain:
self._target = makeDomain(UnstructuredDomain(N_points)) if not isinstance(dom, RGSpace):
raise TypeError
dims = [len(dom.shape) for dom in self.domain]
# FIXME This needs to be removed as soon as the bug below is fixed.
if not dims.count(dims[0]) == len(dims):
raise TypeError(
'This is a bug. Please extend LinearInterpolators functionality!'
)
shp = sampling_points.shape
if not (isinstance(sampling_points, np.ndarray) and len(shp) == 2):
raise TypeError
n_dim, n_points = shp
if not n_dim == reduce(mul, dims):
raise TypeError
self._target = makeDomain(UnstructuredDomain(n_points))
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
self._build_mat(positions, N_points) self._build_mat(sampling_points, n_points)
def _build_mat(self, positions, N_points): def _build_mat(self, sampling_points, N_points):
ndim = positions.shape[0] ndim = sampling_points.shape[0]
mg = np.mgrid[(slice(0, 2),)*ndim] mg = np.mgrid[(slice(0, 2),)*ndim]
mg = np.array(list(map(np.ravel, mg))) mg = np.array(list(map(np.ravel, mg)))
dist = [] dist = [list(dom.distances) for dom in self.domain]
for dom in self.domain: # FIXME This breaks as soon as not all domains have the same number of
if not isinstance(dom, RGSpace): # dimensions.
raise TypeError
dist.append(list(dom.distances))
dist = np.array(dist).reshape(-1, 1) dist = np.array(dist).reshape(-1, 1)
pos = positions/dist pos = sampling_points/dist
excess = pos-pos.astype(np.int64) excess = pos - pos.astype(np.int64)
pos = pos.astype(np.int64) pos = pos.astype(np.int64)
max_index = np.array(self.domain.shape).reshape(-1, 1) max_index = np.array(self.domain.shape).reshape(-1, 1)
data = np.zeros((len(mg[0]), N_points)) data = np.zeros((len(mg[0]), N_points))
ii = np.zeros((len(mg[0]), N_points), dtype=np.int64) ii = np.zeros((len(mg[0]), N_points), dtype=np.int64)
jj = np.zeros((len(mg[0]), N_points), dtype=np.int64) jj = np.zeros((len(mg[0]), N_points), dtype=np.int64)
for i in range(len(mg[0])): for i in range(len(mg[0])):
factor = np.prod(np.abs(1-mg[:, i].reshape(-1, 1)-excess), factor = np.prod(
axis=0) np.abs(1 - mg[:, i].reshape(-1, 1) - excess), axis=0)
data[i, :] = factor data[i, :] = factor
fromi = (pos+mg[:, i].reshape(-1, 1)) % max_index fromi = (pos + mg[:, i].reshape(-1, 1)) % max_index
ii[i, :] = np.arange(N_points) ii[i, :] = np.arange(N_points)
jj[i, :] = np.ravel_multi_index(fromi, self.domain.shape) jj[i, :] = np.ravel_multi_index(fromi, self.domain.shape)
self._mat = coo_matrix((data.reshape(-1), self._mat = coo_matrix((data.reshape(-1),
(ii.reshape(-1), jj.reshape(-1))), (ii.reshape(-1), jj.reshape(-1))),
(N_points, np.prod(self.domain.shape))) (N_points, np.prod(self.domain.shape)))
self._mat = aslinearoperator(self._mat) self._mat = aslinearoperator(self._mat)
...@@ -79,61 +101,6 @@ class LinearInterpolator(LinearOperator): ...@@ -79,61 +101,6 @@ class LinearInterpolator(LinearOperator):
x_val = x.to_global_data() x_val = x.to_global_data()
if mode == self.TIMES: if mode == self.TIMES:
res = self._mat.matvec(x_val.reshape(-1)) res = self._mat.matvec(x_val.reshape(-1))
return Field.from_global_data(self.target, res) else:
res = self._mat.rmatvec(x_val).reshape(self.domain.shape) res = self._mat.rmatvec(x_val).reshape(self.domain.shape)
return Field.from_global_data(self.domain, res) return Field.from_global_data(self._tgt(mode), res)
# import numpy as np
# from ..domains.rg_space import RGSpace
# import itertools
#
#
# class LinearInterpolator(LinearOperator):
# def __init__(self, domain, positions):
# """
# :param domain:
# RGSpace
# :param target:
# UnstructuredDomain, shape (ndata,)
# :param positions:
# positions at which to interpolate
# Field with UnstructuredDomain, shape (dim, ndata)
# """
# if not isinstance(domain, RGSpace):
# raise TypeError("RGSpace needed")
# if np.any(domain.shape < 2):
# raise ValueError("RGSpace shape too small")
# if positions.ndim != 2:
# raise ValueError("positions must be a 2D array")
# ndim = len(domain.shape)
# if positions.shape[0] != ndim:
# raise ValueError("shape mismatch")
# self._domain = makeDomain(domain)
# N_points = positions.shape[1]
# dist = np.array(domain.distances).reshape((ndim, -1))
# self._pos = positions/dist
# shp = np.array(domain.shape, dtype=int).reshape((ndim, -1))
# self._idx = np.maximum(0, np.minimum(shp-2, self._pos.astype(int)))
# self._pos -= self._idx
# tmp = tuple([0, 1] for i in range(ndim))
# self._corners = np.array(list(itertools.product(*tmp)))
# self._target = makeDomain(UnstructuredDomain(N_points))
# self._capability = self.TIMES | self.ADJOINT_TIMES
#
# def apply(self, x, mode):
# self._check_input(x, mode)
# x = x.to_global_data()
# ndim = len(self._domain.shape)
#
# res = np.zeros(self._tgt(mode).shape, dtype=x.dtype)
# for corner in self._corners:
# corner = corner.reshape(ndim, -1)
# idx = self._idx+corner
# idx2 = tuple(idx[t, :] for t in range(idx.shape[0]))
# wgt = np.prod(self._pos*corner+(1-self._pos)*(1-corner), axis=0)
# if mode == self.TIMES:
# res += wgt*x[idx2]
# else:
# np.add.at(res, idx2, wgt*x)
# return Field.from_global_data(self._tgt(mode), res)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment