Commit 76c0839f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

respect the spaces key word

parent 29ae4690
from ... import Field,\
FieldArray
from ..linear_operator import LinearOperator
from ...spaces.power_space import PowerSpace
class PowerProjection(LinearOperator):
def __init__(self, domain, target, spaces=0, default_spaces=None):
def __init__(self, domain, target, default_spaces=None):
self._domain = self._parse_domain(domain)
self._target = self._parse_domain(target)
self.pindex = self.target[spaces].pindex
if len(self._domain)!=1 or len(self._target)!=1:
raise ValueError("Operator only works over one space")
if not self._domain[0].harmonic:
raise ValueError("domain must be a harmonic space")
if not isinstance(self._target[0], PowerSpace):
raise ValueError("target must be a PowerSpace")
self.pindex = self.target[0].pindex
super(PowerProjection, self).__init__(default_spaces)
def _times(self,x,spaces):
projected_x = self.pindex.bincount(weights=x.weight(1).val.real)
y = Field(self.target, val=projected_x).weight(-1)
if spaces is None:
spaces = 0
projected_x = self.pindex.bincount(
weights=x.weight(1,spaces=spaces).val.real,
axis=x.domain_axes[spaces])
tgt_domain = list(x.domain)
tgt_domain[spaces] = self._target[0]
y = Field(tgt_domain, val=projected_x).weight(-1, spaces=spaces)
return y
def _adjoint_times(self,x,spaces):
if spaces is None:
spaces = 0
y = Field(self.domain, val=1.)
axes = x.domain_axes
tgt_domain = list(x.domain)
tgt_domain[spaces] = self._domain[0]
y = Field(tgt_domain, val=1.)
axes = x.domain_axes[spaces]
spec = x.val.get_full_data()
spec = self._spec_to_rescaler(spec, spaces, axes)
......@@ -49,7 +63,7 @@ class PowerProjection(LinearOperator):
local_pindex = self.pindex.get_local_data(copy=False)
local_blow_up = [slice(None)]*len(spec.shape)
local_blow_up[axes[power_space_index][0]] = local_pindex
local_blow_up[axes[power_space_index]] = local_pindex
# here, the power_spectrum is distributed into the new shape
local_rescaler = spec[local_blow_up]
return local_rescaler
......@@ -63,4 +77,4 @@ class PowerProjection(LinearOperator):
@property
def unitary(self):
return False
\ No newline at end of file
return False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment