Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
b35c6926
Commit
b35c6926
authored
Oct 09, 2018
by
Martin Reinecke
Browse files
cosmetics; add a (commented out) alternative version
parent
dc5042f0
Changes
1
Hide whitespace changes
Inline
Side-by-side
nifty5/operators/linear_interpolation.py
View file @
b35c6926
...
...
@@ -20,11 +20,11 @@ from __future__ import absolute_import, division, print_function
from
..compat
import
*
from
..
import
Field
,
UnstructuredDomain
from
..sugar
import
makeDomain
from
..sugar
import
makeDomain
from
.linear_operator
import
LinearOperator
from
numpy
import
array
,
prod
,
mgrid
,
int64
,
arange
,
ravel_multi_index
,
zeros
,
abs
,
ravel
from
numpy
import
(
array
,
prod
,
mgrid
,
int64
,
arange
,
ravel_multi_index
,
zeros
,
abs
,
ravel
)
from
scipy.sparse
import
coo_matrix
from
scipy.sparse.linalg
import
aslinearoperator
...
...
@@ -49,9 +49,9 @@ class LinearInterpolator(LinearOperator):
def
_build_mat
(
self
,
positions
,
N_points
):
ndim
=
positions
.
shape
[
0
]
mg
=
mgrid
[(
slice
(
0
,
2
),)
*
ndim
]
mg
=
mgrid
[(
slice
(
0
,
2
),)
*
ndim
]
mg
=
array
(
list
(
map
(
ravel
,
mg
)))
dist
=
array
(
self
.
domain
[
0
].
distances
).
reshape
((
-
1
,
1
))
dist
=
array
(
self
.
domain
[
0
].
distances
).
reshape
((
-
1
,
1
))
pos
=
positions
/
dist
excess
=
pos
-
pos
.
astype
(
int64
)
pos
=
pos
.
astype
(
int64
)
...
...
@@ -59,17 +59,16 @@ class LinearInterpolator(LinearOperator):
ii
=
zeros
((
len
(
mg
[
0
]),
N_points
),
dtype
=
int64
)
jj
=
zeros
((
len
(
mg
[
0
]),
N_points
),
dtype
=
int64
)
for
i
in
range
(
len
(
mg
[
0
])):
factor
=
prod
(
abs
(
1
-
mg
[:,
i
].
reshape
((
-
1
,
1
))
-
excess
),
axis
=
0
)
#print(factor)
data
[
i
,:]
=
factor
fromi
=
pos
+
mg
[:,
i
].
reshape
((
-
1
,
1
))
factor
=
prod
(
abs
(
1
-
mg
[:,
i
].
reshape
((
-
1
,
1
))
-
excess
),
axis
=
0
)
data
[
i
,
:]
=
factor
fromi
=
pos
+
mg
[:,
i
].
reshape
((
-
1
,
1
))
ii
[
i
,
:]
=
arange
(
N_points
)
jj
[
i
,
:]
=
ravel_multi_index
(
fromi
,
self
.
domain
.
shape
)
self
.
_mat
=
coo_matrix
((
data
.
reshape
(
-
1
),
(
ii
.
reshape
(
-
1
),
jj
.
reshape
(
-
1
))),
(
N_points
,
prod
(
self
.
domain
.
shape
)))
self
.
_mat
=
coo_matrix
((
data
.
reshape
(
-
1
),
(
ii
.
reshape
(
-
1
),
jj
.
reshape
(
-
1
))),
(
N_points
,
prod
(
self
.
domain
.
shape
)))
self
.
_mat
=
aslinearoperator
(
self
.
_mat
)
def
apply
(
self
,
x
,
mode
):
self
.
_check_input
(
x
,
mode
)
x_val
=
x
.
to_global_data
()
...
...
@@ -79,3 +78,57 @@ class LinearInterpolator(LinearOperator):
res
=
self
.
_mat
.
rmatvec
(
x_val
).
reshape
(
self
.
domain
.
shape
)
return
Field
.
from_global_data
(
self
.
domain
,
res
)
# import numpy as np
# from ..domains.rg_space import RGSpace
# import itertools
#
#
# class LinearInterpolator(LinearOperator):
# def __init__(self, domain, positions):
# """
# :param domain:
# RGSpace
# :param target:
# UnstructuredDomain, shape (ndata,)
# :param positions:
# positions at which to interpolate
# Field with UnstructuredDomain, shape (dim, ndata)
# """
# if not isinstance(domain, RGSpace):
# raise TypeError("RGSpace needed")
# if np.any(domain.shape < 2):
# raise ValueError("RGSpace shape too small")
# if positions.ndim != 2:
# raise ValueError("positions must be a 2D array")
# ndim = len(domain.shape)
# if positions.shape[0] != ndim:
# raise ValueError("shape mismatch")
# self._domain = makeDomain(domain)
# N_points = positions.shape[1]
# dist = np.array(domain.distances).reshape((ndim, -1))
# self._pos = positions/dist
# shp = np.array(domain.shape, dtype=int).reshape((ndim, -1))
# self._idx = np.maximum(0, np.minimum(shp-2, self._pos.astype(int)))
# self._pos -= self._idx
# tmp = tuple([0, 1] for i in range(ndim))
# self._corners = np.array(list(itertools.product(*tmp)))
# self._target = makeDomain(UnstructuredDomain(N_points))
# self._capability = self.TIMES | self.ADJOINT_TIMES
#
# def apply(self, x, mode):
# self._check_input(x, mode)
# x = x.to_global_data()
# ndim = len(self._domain.shape)
#
# res = np.zeros(self._tgt(mode).shape, dtype=x.dtype)
# for corner in self._corners:
# corner = corner.reshape(ndim, -1)
# idx = self._idx+corner
# idx2 = tuple(idx[t, :] for t in range(idx.shape[0]))
# wgt = np.prod(self._pos*corner+(1-self._pos)*(1-corner), axis=0)
# if mode == self.TIMES:
# res += wgt*x[idx2]
# else:
# np.add.at(res, idx2, wgt*x)
# return Field.from_global_data(self._tgt(mode), res)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment