Planned maintenance on Wednesday, 2021-01-20, 17:00-18:00. Expect some interruptions during that time

Commit c016b764 authored by Martin Reinecke's avatar Martin Reinecke

tweaks

parent cc20ae5c
Pipeline #71072 failed with stages
in 7 minutes and 57 seconds
import nifty6 as ift
dom=ift.TMP_Domain.make(ift.RGSpace(10,harmonic=True))
dom=ift.TMP_Domain.make(ift.LMSpace(10,10))
op = ift.HarmonicTransformOperator(dom)
bla=ift.TMP_Field.full(dom,1.)
op.inverse(op(bla)).val
op.adjoint(op(bla)).val
......@@ -29,6 +29,19 @@ from .linear_operator import LinearOperator
from .scaling_operator import ScalingOperator
from ..partial_domain import PartialDomain
def extract_single(x, subdom, target):
dct = x.to_dict()
name = subdom.extractName()
fld = x[name]
tgt = target[name]
return fld, tgt, ispc, dct
def reassemble(fld, dct, subdom, target):
name = subdom.extractName()
dct[name] = fld
return TMP_Field.from_dict(dct, target)
class FFTOperator(LinearOperator):
"""Transforms between a pair of position and harmonic RGSpaces.
......@@ -72,30 +85,22 @@ class FFTOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
name = self._space.extractName()
ispc = self._space.extractSpaceIndex()
spc = self._space.extractSpace(self._domain)
x = x.to_dict()
fld = x[name]
ncells = spc.size
fld, tgt, ispc, dct = extract_single(x, self._space, self._tgt(mode))
ncells = fld.domain[ispc].size
if spc.harmonic: # harmonic -> position
func = fft.fftn
fct = 1.
func, fct = fft.fftn, 1.
else:
func = fft.ifftn
fct = ncells
func, fct = fft.ifftn, ncells
axes = fld.domain.axes[ispc]
tdom = self._tgt(mode)[name]
tmp = func(fld.val, axes=axes)
Tval = TMP_fld(tdom, tmp)
Tval = TMP_fld(tgt, tmp)
if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES):
fct *= self._space.extractSpace(self._domain).scalar_dvol
fct *= fld.domain[ispc].scalar_dvol
else:
fct *= self._space.extractSpace(self._target).scalar_dvol
fct *= tgt[ispc].scalar_dvol
if fct != 1:
Tval = Tval*fct
x[self._space.extractName()] = Tval
return TMP_Field.from_dict(x)
return reassemble (Tval, dct, self._space, mode)
class HartleyOperator(LinearOperator):
......@@ -145,27 +150,21 @@ class HartleyOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
name = self._space.extractName()
ispc = self._space.extractSpaceIndex()
spc = self._space.extractSpace(self._domain)
x = x.to_dict()
fld = x[name]
fld, tgt, ispc, dct = extract_single(x, self._space, self._tgt(mode))
if utilities.iscomplextype(fld.dtype):
x[name] = (self._apply_cartesian(fld.real, mode) +
1j*self._apply_cartesian(fld.imag, mode))
tval = (self._apply_cartesian(fld.real, ispc, tgt, mode) +
1j*self._apply_cartesian(fld.imag, ispc, tgt, mode))
else:
x[name] = self._apply_cartesian(fld, mode)
tval = self._apply_cartesian(fld, ispc, tgt, mode)
return reassemble(tval, dct, self._space, mode)
def _apply_cartesian(self, x, mode):
axes = x.domain.axes[self._space]
tdom = self._tgt(mode)
tmp = fft.hartley(x.val, axes=axes)
Tval = TMP_fld(tdom, tmp)
def _apply_cartesian(self, x, ispc, tgt, mode):
tval = TMP_fld(tgt, tmp)
if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES):
fct = self._domain[self._space].scalar_dvol
fct = fld.domain[ispc].scalar_dvol
else:
fct = self._target[self._space].scalar_dvol
return Tval if fct == 1 else Tval*fct
fct = tgt[ispc].scalar_dvol
return tval if fct == 1 else tval*fct
class SHTOperator(LinearOperator):
......@@ -180,16 +179,16 @@ class SHTOperator(LinearOperator):
Parameters
----------
domain : TMP_Space, tuple of TMP_Space or TMP_Domain
domain : Domainoid
The domain of the data that is input by "times" and output by
"adjoint_times".
target : TMP_Space, optional
The target domain of the transform operation.
The target (sub-)domain of the transform operation.
If omitted, a domain will be chosen automatically.
Whenever the input domain of the transform is an RGSpace, the codomain
(and its parameters) are uniquely determined.
For LMSpace, a GLSpace of sufficient resolution is chosen.
space : int, optional
space : Subdomainoid
The index of the domain on which the operator should act
If None, it is set to 0 if domain contains exactly one subdomain.
domain[space] must be a LMSpace.
......@@ -199,17 +198,15 @@ class SHTOperator(LinearOperator):
# Initialize domain and target
self._domain = TMP_Domain.make(domain)
self._capability = self.TIMES | self.ADJOINT_TIMES
self._space = utilities.infer_space(self._domain, space)
self._space = PartialDomain(self._domain, space)
hspc = self._domain[self._space]
hspc = self._space.extractSpace(self._domain)
if not isinstance(hspc, LMSpace):
raise TypeError("SHTOperator only works on a LMSpace domain")
if target is None:
target = hspc.get_default_codomain()
self._target = [dom for dom in self._domain]
self._target[self._space] = target
self._target = TMP_Domain.make(self._target)
self._target = self._space.replaceSpace(self._domain,target)
hspc.check_codomain(target)
target.check_codomain(hspc)
......@@ -229,11 +226,15 @@ class SHTOperator(LinearOperator):
def apply(self, x, mode):
self._check_input(x, mode)
if utilities.iscomplextype(x.dtype):
return (self._apply_spherical(x.real, mode) +
1j*self._apply_spherical(x.imag, mode))
name = self._space.extractName()
x = x.to_dict()
fld = x[name]
if utilities.iscomplextype(fld.dtype):
x[name] = (self._apply_spherical(fld.real, mode) +
1j*self._apply_spherical(fld.imag, mode))
else:
return self._apply_spherical(x, mode)
x[name] = self._apply_spherical(fld, mode)
return TMP_Field.from_dict(x, self._tgt(mode))
def _slice_p2h(self, inp):
rr = self.sjob.alm2map_adjoint(inp)
......@@ -257,14 +258,17 @@ class SHTOperator(LinearOperator):
res = self.sjob.alm2map(res)
return res/np.sqrt(np.pi*4)
def _apply_spherical(self, x, mode):
axes = x.domain.axes[self._space]
v = x.val
def _apply_spherical(self, fld, mode):
name = self._space.extractName()
ispc = self._space.extractSpaceIndex()
spc = self._space.extractSpace(self._domain)
axes = fld.domain.axes[ispc]
v = fld.val
p2h = not x.domain[self._space].harmonic
tdom = self._tgt(mode)
p2h = not fld.domain[ispc].harmonic
tdom = self._tgt(mode)[name]
func = self._slice_p2h if p2h else self._slice_h2p
odat = np.empty(tdom.shape, dtype=x.dtype)
odat = np.empty(tdom.shape, dtype=fld.dtype)
for slice in utilities.get_slice_list(v.shape, axes):
odat[slice] = func(v[slice])
return TMP_fld(tdom, odat)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment