Commit 6b35f65d authored by Martin Reinecke's avatar Martin Reinecke
Browse files

fix power projection; add initial adjointness tester

parent b7122aa0
Pipeline #19535 failed with stage
in 3 minutes and 57 seconds
......@@ -16,12 +16,13 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .. import Field, DomainTuple, nifty_utilities as utilities
from .. import Field, DomainTuple
from ..spaces import PowerSpace
from .linear_operator import LinearOperator
from .. import dobj
import numpy as np
class PowerProjectionOperator(LinearOperator):
def __init__(self, domain, power_space=None, space=None):
super(PowerProjectionOperator, self).__init__()
......@@ -31,7 +32,7 @@ class PowerProjectionOperator(LinearOperator):
if space is None and len(self._domain) == 1:
space = 0
space = int(space)
if space<0 or space>=len(self.domain):
if space < 0 or space >= len(self.domain):
raise ValueError("space index out of range")
hspace = self._domain[space]
if not hspace.harmonic:
......@@ -52,11 +53,13 @@ class PowerProjectionOperator(LinearOperator):
def _times(self, x):
pindex = self._target[self._space].pindex
pindex = pindex.reshape((1, pindex.size, 1))
arr = x.val.reshape(x.domain.collapsed_shape_for_domain(self._space))
arr = x.weight(1).val.reshape(
x.domain.collapsed_shape_for_domain(self._space))
out = dobj.zeros(self._target.collapsed_shape_for_domain(self._space),
dtype=x.dtype)
dtype=x.dtype)
np.add.at(out, (slice(None), pindex.ravel(), slice(None)), arr)
return Field(self._target, out.reshape(self._target.shape)).weight(-1,spaces=self._space)
return Field(self._target, out.reshape(self._target.shape))\
.weight(-1, spaces=self._space)
def _adjoint_times(self, x):
pindex = self._target[self._space].pindex
......
import unittest
import nifty2go as ift
import numpy as np
from itertools import product
from test.common import expand
from numpy.testing import assert_allclose
def _check_adjointness(op, dtype=np.float64):
f1 = ift.Field.from_random("normal",domain=op.domain, dtype=dtype)
f2 = ift.Field.from_random("normal",domain=op.target, dtype=dtype)
assert_allclose(f1.vdot(op.adjoint_times(f2)), op.times(f1).vdot(f2),
rtol=1e-8)
_harmonic_spaces = [ ift.RGSpace(7, distances=0.2, harmonic=True),
ift.RGSpace((12,46), distances=(0.2, 0.3), harmonic=True),
ift.LMSpace(17) ]
class Adjointness_Tests(unittest.TestCase):
@expand(product(_harmonic_spaces, [np.float64, np.complex128]))
def testPPO(self, sp, dtype):
op = ift.PowerProjectionOperator(sp)
_check_adjointness(op, dtype)
ps = ift.PowerSpace(sp, ift.PowerSpace.useful_binbounds(sp, logarithmic=False, nbin=3))
op = ift.PowerProjectionOperator(sp, ps)
_check_adjointness(op, dtype)
ps = ift.PowerSpace(sp, ift.PowerSpace.useful_binbounds(sp, logarithmic=True, nbin=3))
op = ift.PowerProjectionOperator(sp, ps)
_check_adjointness(op, dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment