Commit bfe9690d authored by Theo Steininger's avatar Theo Steininger
Browse files

Added volume weight option to PowerSpace

parent ef5dde19
......@@ -22,7 +22,8 @@ from keepers import Loggable
from future.utils import with_metaclass
class Energy(with_metaclass(NiftyMeta, type('NewBase', (Loggable, object), {}))):
class Energy(with_metaclass(NiftyMeta,
type('NewBase', (Loggable, object), {}))):
""" Provides the functional used by minimization schemes.
The Energy object is an implementation of a scalar function including its
......
......@@ -26,14 +26,14 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
"""
def __init__(self, R, N, S, d, position, inverter=None,
preconditioner=None, fft4exp=None, offset=0., **kwargs):
preconditioner=None, fft4exp=None, prefactor=None, **kwargs):
self._cache = {}
self.R = R
self.N = N
self.S = S
self.d = d
self.position = position
self.offset = offset
self.prefactor = prefactor
if preconditioner is None:
preconditioner = self.S.times
self._domain = self.S.domain
......@@ -56,7 +56,7 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
copy.N = self.N.copy()
copy.S = self.S.copy()
copy.d = self.d.copy()
copy.offset = self.offset
copy.prefactor = self.prefactor
if 'position' in kwargs:
copy.position = kwargs['position']
else:
......@@ -94,7 +94,10 @@ class LogNormalWienerFilterCurvature(InvertibleOperatorMixin,
@property
@memo
def _expp_sspace(self):
return clipped_exp(self._fft(self.position) - self.offset)
result = clipped_exp(self._fft(self.position))
if self.prefactor is not None:
result *= self.prefactor
return result
@property
@memo
......
......@@ -25,13 +25,13 @@ class LogNormalWienerFilterEnergy(Energy):
"""
def __init__(self, position, d, R, N, S, fft4exp=None, old_curvature=None,
offset=0.):
prefactor=None):
super(LogNormalWienerFilterEnergy, self).__init__(position=position)
self.d = d
self.R = R
self.N = N
self.S = S
self.offset = offset
self.prefactor = prefactor
if fft4exp is None:
self._fft = create_composed_fft_operator(self.S.domain,
......@@ -46,7 +46,7 @@ class LogNormalWienerFilterEnergy(Energy):
return self.__class__(position=position, d=self.d, R=self.R, N=self.N,
S=self.S, fft4exp=self._fft,
old_curvature=self._curvature,
offset=self.offset)
prefactor=self.prefactor)
@property
@memo
......@@ -70,7 +70,7 @@ class LogNormalWienerFilterEnergy(Energy):
d=self.d,
position=self.position,
fft4exp=self._fft,
offset=self.offset)
prefactor=self.prefactor)
else:
self._curvature = \
self._old_curvature.copy(position=self.position)
......
......@@ -86,9 +86,10 @@ class PowerSpace(Space):
# ---Overwritten properties and methods---
def __init__(self, harmonic_partner, distribution_strategy=None,
binbounds=None):
volume_type='rho', binbounds=None):
super(PowerSpace, self).__init__()
self._ignore_for_hash += ['_pindex', '_kindex', '_rho']
self._ignore_for_hash += ['_pindex', '_kindex', '_rho',
'_volume_weight']
if distribution_strategy is None:
distribution_strategy = gc['default_distribution_strategy']
......@@ -127,6 +128,8 @@ class PowerSpace(Space):
(self._binbounds, self._pindex, self._kindex, self._rho) = \
self._powerIndexCache[key]
self.volume_type = str(volume_type)
@staticmethod
def _compute_pindex(harmonic_partner, distance_array, binbounds,
distribution_strategy):
......@@ -199,11 +202,31 @@ class PowerSpace(Space):
binbounds=self._binbounds)
def weight(self, x, power, axes, inplace=False):
if self.volume_type == 'unit':
if inplace:
return x
else:
return x.copy()
if self.volume_type == 'rho':
weight = self.rho
elif self.volume_type == 'volume':
try:
weight = self._volume_weight
except AttributeError:
k = self.kindex
weight = np.empty_like(k)
weight[1:-1] = (k[2:] - k[:-2])/2
weight[0] = k[1] - k[0]
weight[-1] = k[-1] - k[-2]
self._volume_weight = weight
reshaper = [1, ] * len(x.shape)
# we know len(axes) is always 1
reshaper[axes[0]] = self.shape[0]
weight = self.rho.reshape(reshaper)
weight = weight.reshape(reshaper)
if power != 1:
weight = weight ** np.float(power)
......
......@@ -16,18 +16,6 @@
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
"""
.. __ ____ __
.. /__/ / _/ / /_
.. __ ___ __ / /_ / _/ __ __
.. / _ | / / / _/ / / / / / /
.. / / / / / / / / / /_ / /_/ /
.. /__/ /__/ /__/ /__/ \___/ \___ / rg
.. /______/
NIFTY submodule for regular Cartesian grids.
"""
from __future__ import division
from builtins import range
from functools import reduce
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment