Commit 1da4ee97 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

get_k_lengths_array returns a Field now

parent 854b9545
Pipeline #20888 passed with stage
in 5 minutes
......@@ -17,13 +17,13 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from __future__ import division
from builtins import range
import numpy as np
from ..field import Field
from ..domain_tuple import DomainTuple
from .endomorphic_operator import EndomorphicOperator
from ..nifty_utilities import cast_iseq_to_tuple
class DiagonalOperator(EndomorphicOperator):
""" NIFTY class for diagonal operators.
......@@ -87,7 +87,7 @@ class DiagonalOperator(EndomorphicOperator):
if nspc > len(set(self._spaces)):
raise ValueError("non-unique space indices")
# if nspc==len(self.diagonal.domain.domains, we could do some optimization
for i, j in enumerate(self._spaces):
for i, j in enumerate(self._spaces):
if diagonal.domain[i] != self._domain[j]:
raise ValueError("domain mismatch")
......
from builtins import range
import numpy as np
from .endomorphic_operator import EndomorphicOperator
from .fft_operator import FFTOperator
from .diagonal_operator import DiagonalOperator
from .. import DomainTuple
class FFTSmoothingOperator(EndomorphicOperator):
def __init__(self, domain, sigma, space=None):
super(FFTSmoothingOperator, self).__init__()
......@@ -16,25 +15,28 @@ class FFTSmoothingOperator(EndomorphicOperator):
raise ValueError("need a Field with exactly one domain")
space = 0
space = int(space)
if (space<0) or space>=len(dom.domains):
if space < 0 or space >= len(dom.domains):
raise ValueError("space index out of range")
self._space = space
self._transformator = FFTOperator(dom, space=space)
codomain = self._transformator.domain[space].get_default_codomain()
self._kernel = codomain.get_k_length_array()
self._FFT = FFTOperator(dom, space=space)
codomain = self._FFT.domain[space].get_default_codomain()
kernel = codomain.get_k_length_array()
smoother = codomain.get_fft_smoothing_kernel_function(self._sigma)
self._kernel = smoother(self._kernel)
kernel = smoother(kernel)
ddom = list(dom)
ddom[space] = codomain
self._diag = DiagonalOperator(kernel, ddom, space)
def _times(self, x):
if self._sigma == 0:
return x.copy()
return self._smooth(x)
return self._FFT.adjoint_times(self._diag(self._FFT(x)))
@property
def domain(self):
return self._transformator.domain
return self._FFT.domain
@property
def self_adjoint(self):
......@@ -43,20 +45,3 @@ class FFTSmoothingOperator(EndomorphicOperator):
@property
def unitary(self):
return False
def _smooth(self, x):
# transform to the (global-)default codomain and perform all remaining
# steps therein
transformed_x = self._transformator(x)
coaxes = transformed_x.domain.axes[self._space]
# now, apply the kernel to transformed_x
# this is done node-locally utilizing numpy's reshaping in order to
# apply the kernel to the correct axes
reshaper = [transformed_x.shape[i] if i in coaxes else 1
for i in range(len(transformed_x.shape))]
transformed_x *= np.reshape(self._kernel, reshaper)
return self._transformator.adjoint_times(transformed_x)
......@@ -21,7 +21,6 @@ from ..field import Field
from ..spaces.power_space import PowerSpace
from .endomorphic_operator import EndomorphicOperator
from .. import DomainTuple
from .. import nifty_utilities as utilities
class LaplaceOperator(EndomorphicOperator):
......
......@@ -19,6 +19,8 @@
from __future__ import division
import numpy as np
from .space import Space
from .. import Field
from ..basic_arithmetics import exp
class LMSpace(Space):
......@@ -96,7 +98,7 @@ class LMSpace(Space):
for l in range(1, lmax+1):
ldist[idx:idx+2*(lmax+1-l)] = tmp[2*l:]
idx += 2*(lmax+1-l)
return ldist
return Field((self,), ldist)
def get_unique_k_lengths(self):
return np.arange(self.lmax+1, dtype=np.float64)
......@@ -106,7 +108,7 @@ class LMSpace(Space):
res = x+1.
res *= x
res *= -0.5*sigma*sigma
np.exp(res, out=res)
exp(res, out=res)
return res
def get_fft_smoothing_kernel_function(self, sigma):
......
......@@ -18,7 +18,6 @@
import numpy as np
from .space import Space
from .. import dobj
class PowerSpace(Space):
......@@ -144,7 +143,7 @@ class PowerSpace(Space):
temp_rho = np.bincount(temp_pindex.ravel())
assert not np.any(temp_rho == 0), "empty bins detected"
temp_k_lengths = np.bincount(temp_pindex.ravel(),
weights=k_length_array.ravel()) \
weights=k_length_array.val.ravel()) \
/ temp_rho
temp_dvol = temp_rho*pdvol
self._powerIndexCache[key] = (binbounds,
......@@ -160,7 +159,7 @@ class PowerSpace(Space):
if binbounds is None:
tmp = harmonic_partner.get_unique_k_lengths()
binbounds = 0.5*(tmp[:-1]+tmp[1:])
return np.searchsorted(binbounds, k_length_array)
return np.searchsorted(binbounds, k_length_array.val)
# ---Mandatory properties and methods---
......
......@@ -21,6 +21,8 @@ from builtins import range
from functools import reduce
import numpy as np
from .space import Space
from .. import Field
from ..basic_arithmetics import exp
class RGSpace(Space):
......@@ -78,14 +80,14 @@ class RGSpace(Space):
res = np.arange(self.shape[0], dtype=np.float64)
res = np.minimum(res, self.shape[0]-res)*self.distances[0]
if len(self.shape) == 1:
return res
return Field((self,), res)
res *= res
for i in range(1, len(self.shape)):
tmp = np.arange(self.shape[i], dtype=np.float64)
tmp = np.minimum(tmp, self.shape[i]-tmp)*self.distances[i]
tmp *= tmp
res = np.add.outer(res, tmp)
return np.sqrt(res)
return Field((self,), np.sqrt(res))
def get_unique_k_lengths(self):
if (not self.harmonic):
......@@ -107,7 +109,7 @@ class RGSpace(Space):
tmp[t2] = True
return np.sqrt(np.nonzero(tmp)[0])*self.distances[0]
else: # do it the hard way
tmp = np.unique(self.get_k_length_array()) # expensive!
tmp = np.unique(self.get_k_length_array().val) # expensive!
tol = 1e-12*tmp[-1]
# remove all points that are closer than tol to their right
# neighbors.
......@@ -119,7 +121,7 @@ class RGSpace(Space):
def _kernel(x, sigma):
tmp = x*x
tmp *= -2.*np.pi*np.pi*sigma*sigma
np.exp(tmp, out=tmp)
exp(tmp, out=tmp)
return tmp
def get_fft_smoothing_kernel_function(self, sigma):
......
......@@ -42,7 +42,7 @@ class Space(DomainObject):
Returns
-------
numpy.ndarray
Field
An array containing the k vector lengths
"""
raise NotImplementedError
......
......@@ -17,13 +17,11 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import unittest
import numpy as np
from numpy.testing import assert_, assert_equal
from numpy.testing import assert_
from itertools import product
from types import LambdaType
from test.common import expand, generate_spaces, generate_harmonic_spaces
import nifty2go as ift
from nifty2go.spaces import *
......@@ -37,7 +35,7 @@ class SpaceInterfaceTests(unittest.TestCase):
attr_expected_type[1]))
@expand(product(generate_harmonic_spaces(), [
['get_k_length_array', np.ndarray],
['get_k_length_array', ift.Field],
['get_fft_smoothing_kernel_function', 2.0, LambdaType],
]))
def test_method_ret_type(self, space, method_expected_type):
......
......@@ -21,7 +21,7 @@ import unittest
import numpy as np
from numpy.testing import assert_, assert_equal, assert_raises,\
assert_almost_equal, assert_array_almost_equal
assert_almost_equal
from nifty2go import LMSpace
from test.common import expand
......@@ -93,4 +93,4 @@ class LMSpaceFunctionalityTests(unittest.TestCase):
@expand(get_k_length_array_configs())
def test_k_length_array(self, lmax, expected):
l = LMSpace(lmax)
assert_almost_equal(l.get_k_length_array(), expected)
assert_almost_equal(l.get_k_length_array().val, expected)
......@@ -21,12 +21,9 @@ from __future__ import division
import unittest
import numpy as np
from numpy.testing import assert_, assert_equal, assert_almost_equal, \
assert_array_equal
from numpy.testing import assert_, assert_equal, assert_almost_equal
from nifty2go import RGSpace
from test.common import expand
from itertools import product
from nose.plugins.skip import SkipTest
# [shape, distances, harmonic, expected]
CONSTRUCTOR_CONFIGS = [
......@@ -115,7 +112,7 @@ class RGSpaceFunctionalityTests(unittest.TestCase):
@expand(get_k_length_array_configs())
def test_k_length_array(self, shape, distances, expected):
r = RGSpace(shape=shape, distances=distances, harmonic=True)
assert_almost_equal(r.get_k_length_array(), expected)
assert_almost_equal(r.get_k_length_array().val, expected)
@expand(get_dvol_configs())
def test_dvol(self, shape, distances, harmonic, power):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment