Commit 1f7427e1 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

convenience method for sanitizing the space keyword

parent 6d44c94b
Pipeline #22911 passed with stage
in 4 minutes and 41 seconds
import numpy as np
from .linear_operator import LinearOperator
from ..utilities import infer_space
from .. import Field, DomainTuple, dobj
from ..spaces import DOFSpace
......@@ -11,11 +12,7 @@ class DOFProjectionOperator(LinearOperator):
if domain is None:
domain = dofdex.domain
self._domain = DomainTuple.make(domain)
if space is None and len(self._domain) == 1:
space = 0
space = int(space)
if space < 0 or space >= len(self.domain):
raise ValueError("space index out of range")
space = infer_space(self._domain, space)
partner = self._domain[space]
if not isinstance(dofdex, Field):
raise TypeError("dofdex must be a Field")
......
......@@ -19,6 +19,7 @@
import numpy as np
from .. import DomainTuple
from ..spaces import RGSpace
from ..utilities import infer_space
from .linear_operator import LinearOperator
from .fft_operator_support import RGRGTransformation, SphericalTransformation
......@@ -77,14 +78,7 @@ class FFTOperator(LinearOperator):
# Initialize domain and target
self._domain = DomainTuple.make(domain)
if space is None:
if len(self._domain) != 1:
raise ValueError("need a Field with exactly one domain")
space = 0
space = int(space)
if space < 0 or space >= len(self._domain):
raise ValueError("space index out of range")
self._space = space
self._space = infer_space(self._domain, space)
adom = self.domain[self._space]
if target is None:
......@@ -96,14 +90,14 @@ class FFTOperator(LinearOperator):
adom.check_codomain(target)
target.check_codomain(adom)
if self._target[space].harmonic:
if self._target[self._space].harmonic:
pdom, hdom = (self._domain, self._target)
else:
pdom, hdom = (self._target, self._domain)
if isinstance(pdom[space], RGSpace):
self._trafo = RGRGTransformation(pdom, hdom, space)
if isinstance(pdom[self._space], RGSpace):
self._trafo = RGRGTransformation(pdom, hdom, self._space)
else:
self._trafo = SphericalTransformation(pdom, hdom, space)
self._trafo = SphericalTransformation(pdom, hdom, self._space)
def _times_helper(self, x):
if np.issubdtype(x.dtype, np.complexfloating):
......
from .endomorphic_operator import EndomorphicOperator
from .fft_operator import FFTOperator
from ..utilities import infer_space
from .diagonal_operator import DiagonalOperator
from .. import DomainTuple
......@@ -10,23 +11,16 @@ class FFTSmoothingOperator(EndomorphicOperator):
dom = DomainTuple.make(domain)
self._sigma = float(sigma)
if space is None:
if len(dom) != 1:
raise ValueError("need a Field with exactly one domain")
space = 0
space = int(space)
if space < 0 or space >= len(dom):
raise ValueError("space index out of range")
self._space = space
self._FFT = FFTOperator(dom, space=space)
codomain = self._FFT.domain[space].get_default_codomain()
self._space = infer_space(dom, space)
self._FFT = FFTOperator(dom, space=self._space)
codomain = self._FFT.domain[self._space].get_default_codomain()
kernel = codomain.get_k_length_array()
smoother = codomain.get_fft_smoothing_kernel_function(self._sigma)
kernel = smoother(kernel)
ddom = list(dom)
ddom[space] = codomain
self._diag = DiagonalOperator(kernel, ddom, space)
ddom[self._space] = codomain
self._diag = DiagonalOperator(kernel, ddom, self._space)
def _times(self, x):
if self._sigma == 0:
......
......@@ -20,6 +20,7 @@ import numpy as np
from ..field import Field
from ..spaces.power_space import PowerSpace
from .endomorphic_operator import EndomorphicOperator
from ..utilities import infer_space
from .. import DomainTuple
from .. import dobj
......@@ -44,14 +45,7 @@ class LaplaceOperator(EndomorphicOperator):
def __init__(self, domain, space=None, logarithmic=True):
super(LaplaceOperator, self).__init__()
self._domain = DomainTuple.make(domain)
if space is None:
if len(self._domain) != 1:
raise ValueError("need a Field with exactly one domain")
space = 0
space = int(space)
if space < 0 or space >= len(self._domain):
raise ValueError("space index out of range")
self._space = space
self._space = infer_space(self._domain, space)
if not isinstance(self._domain[self._space], PowerSpace):
raise ValueError("Operator must act on a PowerSpace.")
......
......@@ -19,6 +19,7 @@
import numpy as np
from .dof_projection_operator import DOFProjectionOperator
from .. import Field, DomainTuple, dobj
from ..utilities import infer_space
from ..spaces import PowerSpace
......@@ -26,12 +27,8 @@ class PowerProjectionOperator(DOFProjectionOperator):
def __init__(self, domain, power_space=None, space=None):
# Initialize domain and target
self._domain = DomainTuple.make(domain)
if space is None and len(self._domain) == 1:
space = 0
space = int(space)
if space < 0 or space >= len(self.domain):
raise ValueError("space index out of range")
hspace = self._domain[space]
self._space = infer_space(self._domain, space)
hspace = self._domain[self._space]
if not hspace.harmonic:
raise ValueError("Operator acts on harmonic spaces only")
if power_space is None:
......@@ -42,4 +39,4 @@ class PowerProjectionOperator(DOFProjectionOperator):
if power_space.harmonic_partner != hspace:
raise ValueError("power_space does not match its partner")
self._init2(power_space.pindex, space, power_space)
self._init2(power_space.pindex, self._space, power_space)
......@@ -73,6 +73,17 @@ def cast_iseq_to_tuple(seq):
return tuple(int(item) for item in seq)
def infer_space(domain, space):
if space is None:
if len(domain) != 1:
raise ValueError("need a Field with exactly one domain")
space = 0
space = int(space)
if space < 0 or space >= len(domain):
raise ValueError("space index out of range")
return space
def memo(f):
name = f.__name__
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment