Skip to content
Snippets Groups Projects
Commit c016b764 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

tweaks

parent cc20ae5c
No related branches found
No related tags found
No related merge requests found
Pipeline #71072 failed
import nifty6 as ift import nifty6 as ift
dom=ift.TMP_Domain.make(ift.RGSpace(10,harmonic=True)) dom=ift.TMP_Domain.make(ift.LMSpace(10,10))
op = ift.HarmonicTransformOperator(dom) op = ift.HarmonicTransformOperator(dom)
bla=ift.TMP_Field.full(dom,1.) bla=ift.TMP_Field.full(dom,1.)
op.inverse(op(bla)).val op.adjoint(op(bla)).val
...@@ -29,6 +29,19 @@ from .linear_operator import LinearOperator ...@@ -29,6 +29,19 @@ from .linear_operator import LinearOperator
from .scaling_operator import ScalingOperator from .scaling_operator import ScalingOperator
from ..partial_domain import PartialDomain from ..partial_domain import PartialDomain
def extract_single(x, subdom, target):
dct = x.to_dict()
name = subdom.extractName()
fld = x[name]
tgt = target[name]
return fld, tgt, ispc, dct
def reassemble(fld, dct, subdom, target):
name = subdom.extractName()
dct[name] = fld
return TMP_Field.from_dict(dct, target)
class FFTOperator(LinearOperator): class FFTOperator(LinearOperator):
"""Transforms between a pair of position and harmonic RGSpaces. """Transforms between a pair of position and harmonic RGSpaces.
...@@ -72,30 +85,22 @@ class FFTOperator(LinearOperator): ...@@ -72,30 +85,22 @@ class FFTOperator(LinearOperator):
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
name = self._space.extractName() fld, tgt, ispc, dct = extract_single(x, self._space, self._tgt(mode))
ispc = self._space.extractSpaceIndex() ncells = fld.domain[ispc].size
spc = self._space.extractSpace(self._domain)
x = x.to_dict()
fld = x[name]
ncells = spc.size
if spc.harmonic: # harmonic -> position if spc.harmonic: # harmonic -> position
func = fft.fftn func, fct = fft.fftn, 1.
fct = 1.
else: else:
func = fft.ifftn func, fct = fft.ifftn, ncells
fct = ncells
axes = fld.domain.axes[ispc] axes = fld.domain.axes[ispc]
tdom = self._tgt(mode)[name]
tmp = func(fld.val, axes=axes) tmp = func(fld.val, axes=axes)
Tval = TMP_fld(tdom, tmp) Tval = TMP_fld(tgt, tmp)
if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES): if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES):
fct *= self._space.extractSpace(self._domain).scalar_dvol fct *= fld.domain[ispc].scalar_dvol
else: else:
fct *= self._space.extractSpace(self._target).scalar_dvol fct *= tgt[ispc].scalar_dvol
if fct != 1: if fct != 1:
Tval = Tval*fct Tval = Tval*fct
x[self._space.extractName()] = Tval return reassemble (Tval, dct, self._space, mode)
return TMP_Field.from_dict(x)
class HartleyOperator(LinearOperator): class HartleyOperator(LinearOperator):
...@@ -145,27 +150,21 @@ class HartleyOperator(LinearOperator): ...@@ -145,27 +150,21 @@ class HartleyOperator(LinearOperator):
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
name = self._space.extractName() fld, tgt, ispc, dct = extract_single(x, self._space, self._tgt(mode))
ispc = self._space.extractSpaceIndex()
spc = self._space.extractSpace(self._domain)
x = x.to_dict()
fld = x[name]
if utilities.iscomplextype(fld.dtype): if utilities.iscomplextype(fld.dtype):
x[name] = (self._apply_cartesian(fld.real, mode) + tval = (self._apply_cartesian(fld.real, ispc, tgt, mode) +
1j*self._apply_cartesian(fld.imag, mode)) 1j*self._apply_cartesian(fld.imag, ispc, tgt, mode))
else: else:
x[name] = self._apply_cartesian(fld, mode) tval = self._apply_cartesian(fld, ispc, tgt, mode)
return reassemble(tval, dct, self._space, mode)
def _apply_cartesian(self, x, mode): def _apply_cartesian(self, x, ispc, tgt, mode):
axes = x.domain.axes[self._space] tval = TMP_fld(tgt, tmp)
tdom = self._tgt(mode)
tmp = fft.hartley(x.val, axes=axes)
Tval = TMP_fld(tdom, tmp)
if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES): if mode & (LinearOperator.TIMES | LinearOperator.ADJOINT_TIMES):
fct = self._domain[self._space].scalar_dvol fct = fld.domain[ispc].scalar_dvol
else: else:
fct = self._target[self._space].scalar_dvol fct = tgt[ispc].scalar_dvol
return Tval if fct == 1 else Tval*fct return tval if fct == 1 else tval*fct
class SHTOperator(LinearOperator): class SHTOperator(LinearOperator):
...@@ -180,16 +179,16 @@ class SHTOperator(LinearOperator): ...@@ -180,16 +179,16 @@ class SHTOperator(LinearOperator):
Parameters Parameters
---------- ----------
domain : TMP_Space, tuple of TMP_Space or TMP_Domain domain : Domainoid
The domain of the data that is input by "times" and output by The domain of the data that is input by "times" and output by
"adjoint_times". "adjoint_times".
target : TMP_Space, optional target : TMP_Space, optional
The target domain of the transform operation. The target (sub-)domain of the transform operation.
If omitted, a domain will be chosen automatically. If omitted, a domain will be chosen automatically.
Whenever the input domain of the transform is an RGSpace, the codomain Whenever the input domain of the transform is an RGSpace, the codomain
(and its parameters) are uniquely determined. (and its parameters) are uniquely determined.
For LMSpace, a GLSpace of sufficient resolution is chosen. For LMSpace, a GLSpace of sufficient resolution is chosen.
space : int, optional space : Subdomainoid
The index of the domain on which the operator should act The index of the domain on which the operator should act
If None, it is set to 0 if domain contains exactly one subdomain. If None, it is set to 0 if domain contains exactly one subdomain.
domain[space] must be a LMSpace. domain[space] must be a LMSpace.
...@@ -199,17 +198,15 @@ class SHTOperator(LinearOperator): ...@@ -199,17 +198,15 @@ class SHTOperator(LinearOperator):
# Initialize domain and target # Initialize domain and target
self._domain = TMP_Domain.make(domain) self._domain = TMP_Domain.make(domain)
self._capability = self.TIMES | self.ADJOINT_TIMES self._capability = self.TIMES | self.ADJOINT_TIMES
self._space = utilities.infer_space(self._domain, space) self._space = PartialDomain(self._domain, space)
hspc = self._domain[self._space] hspc = self._space.extractSpace(self._domain)
if not isinstance(hspc, LMSpace): if not isinstance(hspc, LMSpace):
raise TypeError("SHTOperator only works on a LMSpace domain") raise TypeError("SHTOperator only works on a LMSpace domain")
if target is None: if target is None:
target = hspc.get_default_codomain() target = hspc.get_default_codomain()
self._target = [dom for dom in self._domain] self._target = self._space.replaceSpace(self._domain,target)
self._target[self._space] = target
self._target = TMP_Domain.make(self._target)
hspc.check_codomain(target) hspc.check_codomain(target)
target.check_codomain(hspc) target.check_codomain(hspc)
...@@ -229,11 +226,15 @@ class SHTOperator(LinearOperator): ...@@ -229,11 +226,15 @@ class SHTOperator(LinearOperator):
def apply(self, x, mode): def apply(self, x, mode):
self._check_input(x, mode) self._check_input(x, mode)
if utilities.iscomplextype(x.dtype): name = self._space.extractName()
return (self._apply_spherical(x.real, mode) + x = x.to_dict()
1j*self._apply_spherical(x.imag, mode)) fld = x[name]
if utilities.iscomplextype(fld.dtype):
x[name] = (self._apply_spherical(fld.real, mode) +
1j*self._apply_spherical(fld.imag, mode))
else: else:
return self._apply_spherical(x, mode) x[name] = self._apply_spherical(fld, mode)
return TMP_Field.from_dict(x, self._tgt(mode))
def _slice_p2h(self, inp): def _slice_p2h(self, inp):
rr = self.sjob.alm2map_adjoint(inp) rr = self.sjob.alm2map_adjoint(inp)
...@@ -257,14 +258,17 @@ class SHTOperator(LinearOperator): ...@@ -257,14 +258,17 @@ class SHTOperator(LinearOperator):
res = self.sjob.alm2map(res) res = self.sjob.alm2map(res)
return res/np.sqrt(np.pi*4) return res/np.sqrt(np.pi*4)
def _apply_spherical(self, x, mode): def _apply_spherical(self, fld, mode):
axes = x.domain.axes[self._space] name = self._space.extractName()
v = x.val ispc = self._space.extractSpaceIndex()
spc = self._space.extractSpace(self._domain)
axes = fld.domain.axes[ispc]
v = fld.val
p2h = not x.domain[self._space].harmonic p2h = not fld.domain[ispc].harmonic
tdom = self._tgt(mode) tdom = self._tgt(mode)[name]
func = self._slice_p2h if p2h else self._slice_h2p func = self._slice_p2h if p2h else self._slice_h2p
odat = np.empty(tdom.shape, dtype=x.dtype) odat = np.empty(tdom.shape, dtype=fld.dtype)
for slice in utilities.get_slice_list(v.shape, axes): for slice in utilities.get_slice_list(v.shape, axes):
odat[slice] = func(v[slice]) odat[slice] = func(v[slice])
return TMP_fld(tdom, odat) return TMP_fld(tdom, odat)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment