Commit 8331847f authored by Martin Reinecke's avatar Martin Reinecke
Browse files

introduce PowerProjectionOperator

parent 09c9b645
Pipeline #19441 failed with stage
in 3 minutes and 59 seconds
......@@ -28,6 +28,11 @@ class DomainTuple(object):
shape_tuple = tuple(sp.shape for sp in self._dom)
self._shape = reduce(lambda x, y: x + y, shape_tuple, ())
self._dim = reduce(lambda x, y: x * y, self._shape, 1)
self._accdims = (1,)
prod = 1
for dom in self._dom:
prod *= dom.dim
self._accdims += (prod,)
def _get_axes_tuple(self):
i = 0
......@@ -110,3 +115,14 @@ class DomainTuple(object):
for i in self.domains:
res += "\n" + str(i)
return res
def collapsed_shape_for_domain(self, ispace):
"""Returns a three-component shape, with the total number of pixels
in the domains before the requested space in res[0], the total number
of pixels in the requested space in res[1], and the remaining pixels in
res[2].
"""
dims = (dom.dim for dom in self._dom)
return (self._accdims[ispace],
self._accdims[ispace+1]/self._accdims[ispace],
self._accdims[-1]/self._accdims[ispace+1])
......@@ -26,7 +26,6 @@ from .domain_tuple import DomainTuple
from functools import reduce
from . import dobj
class Field(object):
""" The discrete representation of a continuous field over multiple spaces.
......@@ -219,27 +218,19 @@ class Field(object):
idx=space_index,
binbounds=binbounds)
for part in parts]
parts = [ part.weight(-1,spaces) for part in parts ]
return parts[0] + 1j*parts[1] if keep_phase_information else parts[0]
@staticmethod
def _single_power_analyze(field, idx, binbounds):
from .operators.power_projection_operator import PowerProjectionOperator
power_domain = PowerSpace(field.domain[idx], binbounds)
pindex = power_domain.pindex
axes = field.domain.axes[idx]
new_pindex_shape = [1] * len(field.shape)
for i, ax in enumerate(axes):
new_pindex_shape[ax] = pindex.shape[i]
pindex = np.broadcast_to(pindex.reshape(new_pindex_shape), field.shape)
power_spectrum = dobj.bincount_axis(pindex, weights=field.val,
axis=axes)
result_domain = list(field.domain)
result_domain[idx] = power_domain
return Field(result_domain, power_spectrum)
ppo = PowerProjectionOperator(field.domain,idx,power_domain)
return ppo(field)
def _compute_spec(self, spaces):
from .operators.power_projection_operator import PowerProjectionOperator
from .basic_arithmetics import sqrt
if spaces is None:
spaces = range(len(self.domain))
else:
......@@ -247,23 +238,14 @@ class Field(object):
# create the result domain
result_domain = list(self.domain)
spec = sqrt(self)
for i in spaces:
if not isinstance(self.domain[i], PowerSpace):
raise ValueError("A PowerSpace is needed for field "
"synthetization.")
result_domain[i] = self.domain[i].harmonic_partner
ppo = PowerProjectionOperator(result_domain,i,self.domain[i])
spec = ppo.adjoint_times(spec)
spec = dobj.sqrt(self.val)
for i in spaces:
power_space = self.domain[i]
local_blow_up = [slice(None)]*len(spec.shape)
# it is important to count from behind, since spec potentially
# grows with every iteration
index = self.domain.axes[i][0]-len(self.shape)
local_blow_up[index] = power_space.pindex
# here, the power_spectrum is distributed into the new shape
spec = spec[local_blow_up]
return Field(result_domain, val=spec)
return spec
def power_synthesize(self, spaces=None, real_power=True, real_signal=True):
""" Yields a sampled field with `self`**2 as its power spectrum.
......
......@@ -38,3 +38,5 @@ from .response_operator import ResponseOperator
from .laplace_operator import LaplaceOperator
from .smoothness_operator import SmoothnessOperator
from .power_projection_operator import PowerProjectionOperator
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from .. import Field, DomainTuple, nifty_utilities as utilities
from ..spaces import PowerSpace
from .linear_operator import LinearOperator
from .. import dobj
import numpy as np
class PowerProjectionOperator(LinearOperator):
def __init__(self, domain, space, power_space=None):
super(PowerProjectionOperator, self).__init__()
# Initialize domain and target
self._domain = DomainTuple.make(domain)
if space is None and len(self._domain) == 1:
space = 0
space = int(space)
if space<0 or space>=len(self.domain):
raise ValueError("space index out of range")
hspace = self._domain[space]
if not hspace.harmonic:
raise ValueError("Operator acts on harmonic spaces only")
if isinstance(hspace, PowerSpace):
raise TypeError("Operator cannot act on PowerSpaces")
if power_space is None:
power_space = PowerSpace(hspace)
else:
if not isinstance(power_space, PowerSpace):
raise TypeError("power_space argument must be a PowerSpace")
if power_space.harmonic_partner != hspace:
raise ValueError("power_space does not match its partner")
self._space = space
tgt = list(self._domain)
tgt[self._space] = power_space
self._target = DomainTuple.make(tgt)
def _times(self, x):
pindex = self._target[self._space].pindex
pindex = pindex.reshape((1, pindex.size, 1))
arr = x.val.reshape(x.domain.collapsed_shape_for_domain(self._space))
out = dobj.zeros(self._target.collapsed_shape_for_domain(self._space),
dtype=x.dtype)
np.add.at(out, (slice(None), pindex.ravel(), slice(None)), arr)
return Field(self._target, out.reshape(self._target.shape)).weight(-1,spaces=self._space)
def _adjoint_times(self, x):
pindex = self._target[self._space].pindex
pindex = pindex.reshape((1, pindex.size, 1))
arr = x.val.reshape(x.domain.collapsed_shape_for_domain(self._space))
out = dobj.zeros(self._domain.collapsed_shape_for_domain(self._space),
dtype=x.dtype)
out[()] = arr[(slice(None), pindex.ravel(), slice(None))]
return Field(self._domain, out.reshape(self._domain.shape))
@property
def domain(self):
return self._domain
@property
def target(self):
return self._target
@property
def unitary(self):
return False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment