Commit da51048a authored by Theo Steininger's avatar Theo Steininger
Browse files

Merge branch 'byebye_zerocenter' into 'nightly'

Byebye zerocenter

See merge request !194
parents fb6d3eb1 0cfd3b1c
......@@ -29,8 +29,7 @@ from .transformations import RGRGTransformation,\
LMGLTransformation,\
LMHPTransformation,\
GLLMTransformation,\
HPLMTransformation,\
TransformationCache
HPLMTransformation
class FFTOperator(LinearOperator):
......@@ -57,7 +56,7 @@ class FFTOperator(LinearOperator):
"adjoint_times".
If omitted, a co-domain will be chosen automatically.
Whenever "domain" is an RGSpace, the codomain (and its parameters) are
uniquely determined (except for "zerocenter").
uniquely determined.
For GLSpace, HPSpace, and LMSpace, a sensible (but not unique)
co-domain is chosen that should work satisfactorily in most situations,
but for full control, the user should explicitly specify a codomain.
......@@ -135,11 +134,11 @@ class FFTOperator(LinearOperator):
backward_class = self.transformation_dictionary[
(self.target[0].__class__, self.domain[0].__class__)]
self._forward_transformation = TransformationCache.create(
forward_class, self.domain[0], self.target[0], module=module)
self._forward_transformation = forward_class(
self.domain[0], self.target[0], module=module)
self._backward_transformation = TransformationCache.create(
backward_class, self.target[0], self.domain[0], module=module)
self._backward_transformation = backward_class(
self.target[0], self.domain[0], module=module)
# Store the dtype information
self.domain_dtype = \
......@@ -233,7 +232,7 @@ class FFTOperator(LinearOperator):
A (more or less perfect) counterpart to "domain" with respect
to a FFT operation.
Whenever "domain" is an RGSpace, the codomain (and its parameters)
are uniquely determined (except for "zerocenter").
are uniquely determined.
For GLSpace, HPSpace, and LMSpace, a sensible (but not unique)
co-domain is chosen that should work satisfactorily in most
situations. For full control however, the user should not rely on
......
......@@ -21,5 +21,3 @@ from .gllmtransformation import GLLMTransformation
from .hplmtransformation import HPLMTransformation
from .lmgltransformation import LMGLTransformation
from .lmhptransformation import LMHPTransformation
from .transformation_cache import TransformationCache
......@@ -27,7 +27,6 @@ from ....config import nifty_configuration as gc
from .... import nifty_utilities as utilities
from keepers import Loggable
from functools import reduce
fftw = gdi.get('fftw')
......@@ -41,150 +40,6 @@ class Transform(Loggable, object):
self.domain = domain
self.codomain = codomain
# initialize the dictionary which stores the values from
# get_centering_mask
self.centering_mask_dict = {}
def get_centering_mask(self, to_center_input, dimensions_input,
offset_input=False):
"""
Computes the mask, used to (de-)zerocenter domain and target
fields.
Parameters
----------
to_center_input : tuple, list, numpy.ndarray
A tuple of booleans which dimensions should be
zero-centered.
dimensions_input : tuple, list, numpy.ndarray
A tuple containing the mask's desired shape.
offset_input : int, boolean
Specifies whether the zero-th dimension starts with an odd
or and even index, i.e. if it is shifted.
Returns
-------
result : np.ndarray
A 1/-1-alternating mask.
"""
# cast input
to_center = np.array(to_center_input)
dimensions = np.array(dimensions_input)
# if none of the dimensions are zero centered, return a 1
if np.all(to_center == 0):
return 1
if np.all(dimensions == np.array(1)) or \
np.all(dimensions == np.array([1])):
return dimensions
# The dimensions of size 1 must be sorted out for computing the
# centering_mask. The depth of the array will be restored in the
# end.
size_one_dimensions = []
temp_dimensions = []
temp_to_center = []
for i in range(len(dimensions)):
if dimensions[i] == 1:
size_one_dimensions += [True]
else:
size_one_dimensions += [False]
temp_dimensions += [dimensions[i]]
temp_to_center += [to_center[i]]
dimensions = np.array(temp_dimensions)
to_center = np.array(temp_to_center)
# cast the offset_input into the shape of to_center
offset = np.zeros(to_center.shape, dtype=int)
# if the first dimension has length 1 and has an offset, restore the
# global minus by hand
if not size_one_dimensions[0]:
offset[0] = int(offset_input)
# check for dimension match
if to_center.size != dimensions.size:
raise TypeError(
'The length of the supplied lists does not match.')
# build up the value memory
# compute an identifier for the parameter set
temp_id = tuple(
(tuple(to_center), tuple(dimensions), tuple(offset)))
if temp_id not in self.centering_mask_dict:
# use np.tile in order to stack the core alternation scheme
# until the desired format is constructed.
core = np.fromfunction(
lambda *args: (-1) **
(np.tensordot(to_center,
args +
offset.reshape(offset.shape +
(1,) *
(np.array(
args).ndim - 1)),
1)),
(2,) * to_center.size)
# Cast the core to the smallest integers we can get
core = core.astype(np.int8)
centering_mask = np.tile(core, dimensions // 2)
# for the dimensions of odd size corresponding slices must be
# added
for i in range(centering_mask.ndim):
# check if the size of the certain dimension is odd or even
if (dimensions % 2)[i] == 0:
continue
# prepare the slice object
temp_slice = (slice(None),) * i + (slice(-2, -1, 1),) + \
(slice(None),) * (centering_mask.ndim - 1 - i)
# append the slice to the centering_mask
centering_mask = np.append(centering_mask,
centering_mask[temp_slice],
axis=i)
# Add depth to the centering_mask where the length of a
# dimension was one
temp_slice = ()
for i in range(len(size_one_dimensions)):
if size_one_dimensions[i]:
temp_slice += (None,)
else:
temp_slice += (slice(None),)
centering_mask = centering_mask[temp_slice]
# if the first dimension has length 1 and has an offset, restore
# the global minus by hand
if size_one_dimensions[0] and offset_input:
centering_mask *= -1
self.centering_mask_dict[temp_id] = centering_mask
return self.centering_mask_dict[temp_id]
def _apply_mask(self, val, mask, axes):
"""
Apply centering mask to an array.
Parameters
----------
val: distributed_data_object or numpy.ndarray
The value-array on which the mask should be applied.
mask: numpy.ndarray
The mask to be applied.
axes: tuple
The axes which are to be transformed.
Returns
-------
distributed_data_object or np.nd_array
Mask input array by multiplying it with the mask.
"""
# reshape mask if necessary
if axes:
mask = mask.reshape(
[y if x in axes else 1
for x, y in enumerate(val.shape)]
)
return val * mask
def transform(self, val, axes, **kwargs):
"""
A generic ff-transform function.
......@@ -246,11 +101,6 @@ class MPIFFT(Transform):
return self.info_dict[temp_id]
def _atomic_mpi_transform(self, val, info, axes):
# Apply codomain centering mask
if reduce(lambda x, y: x + y, self.codomain.zerocenter):
temp_val = np.copy(val)
val = self._apply_mask(temp_val, info.cmask_codomain, axes)
p = info.plan
# Load the value into the plan
if p.has_input:
......@@ -272,13 +122,6 @@ class MPIFFT(Transform):
else:
return None
# Apply domain centering mask
if reduce(lambda x, y: x + y, self.domain.zerocenter):
result = self._apply_mask(result, info.cmask_domain, axes)
# Correct the sign if needed
result *= info.sign
return result
def _local_transform(self, val, axes, **kwargs):
......@@ -299,27 +142,12 @@ class MPIFFT(Transform):
is_local=True,
**kwargs)
# Apply codomain centering mask
if reduce(lambda x, y: x + y, self.codomain.zerocenter):
temp_val = np.copy(local_val)
local_val = self._apply_mask(temp_val,
current_info.cmask_codomain, axes)
local_result = current_info.fftw_interface(
local_val,
axes=axes,
planner_effort='FFTW_ESTIMATE'
)
# Apply domain centering mask
if reduce(lambda x, y: x + y, self.domain.zerocenter):
local_result = self._apply_mask(local_result,
current_info.cmask_domain, axes)
# Correct the sign if needed
if current_info.sign != 1:
local_result *= current_info.sign
try:
# Create return object and insert results inplace
return_val = val.copy_empty(global_shape=val.shape,
......@@ -470,36 +298,6 @@ class FFTWTransformInfo(object):
raise ImportError(
"The MPI FFTW module is needed but not available.")
shape = (local_shape if axes is None else
[y for x, y in enumerate(local_shape) if x in axes])
self._cmask_domain = fftw_context.get_centering_mask(domain.zerocenter,
shape,
local_offset_Q)
self._cmask_codomain = fftw_context.get_centering_mask(
codomain.zerocenter,
shape,
local_offset_Q)
# If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it.
self._sign = (-1) ** np.sum(np.array(domain.zerocenter) *
np.array(codomain.zerocenter) *
(np.array(domain.shape) // 2 % 2))
@property
def cmask_domain(self):
return self._cmask_domain
@property
def cmask_codomain(self):
return self._cmask_codomain
@property
def sign(self):
return self._sign
class FFTWLocalTransformInfo(FFTWTransformInfo):
def __init__(self, domain, codomain, axes, local_shape,
......@@ -623,20 +421,6 @@ class SerialFFT(Transform):
return return_val
def _atomic_transform(self, local_val, axes, local_offset_Q):
# some auxiliaries for the mask computation
local_shape = local_val.shape
shape = (local_shape if axes is None else
[y for x, y in enumerate(local_shape) if x in axes])
# Apply codomain centering mask
if reduce(lambda x, y: x + y, self.codomain.zerocenter):
temp_val = np.copy(local_val)
mask = self.get_centering_mask(self.codomain.zerocenter,
shape,
local_offset_Q)
local_val = self._apply_mask(temp_val, mask, axes)
# perform the transformation
if self._use_fftw:
if self.codomain.harmonic:
......@@ -651,19 +435,4 @@ class SerialFFT(Transform):
else:
result_val = np.fft.ifftn(local_val, axes=axes)
# Apply domain centering mask
if reduce(lambda x, y: x + y, self.domain.zerocenter):
mask = self.get_centering_mask(self.domain.zerocenter,
shape,
local_offset_Q)
result_val = self._apply_mask(result_val, mask, axes)
# If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it.
sign = (-1) ** np.sum(np.array(self.domain.zerocenter) *
np.array(self.codomain.zerocenter) *
(np.array(self.domain.shape) // 2 % 2))
if sign != 1:
result_val *= sign
return result_val
......@@ -53,7 +53,7 @@ class RGRGTransformation(Transformation):
return True
@classmethod
def get_codomain(cls, domain, zerocenter=None):
def get_codomain(cls, domain):
"""
Generates a compatible codomain to which transformations are
reasonable, i.e.\ either a shifted grid or a Fourier conjugate
......@@ -63,9 +63,6 @@ class RGRGTransformation(Transformation):
----------
domain: RGSpace
Space for which a codomain is to be generated
zerocenter : {bool, numpy.ndarray}, *optional*
Whether or not the grid is zerocentered for each axis or not
(default: None).
Returns
-------
......@@ -75,21 +72,11 @@ class RGRGTransformation(Transformation):
if not isinstance(domain, RGSpace):
raise TypeError("domain needs to be a RGSpace")
# parse the zerocenter input
if zerocenter is None:
zerocenter = domain.zerocenter
# if the input is something scalar, cast it to a boolean
else:
temp = np.empty_like(domain.zerocenter)
temp[:] = zerocenter
zerocenter = temp
# calculate the initialization parameters
distances = 1. / (np.array(domain.shape) *
np.array(domain.distances))
new_space = RGSpace(domain.shape,
zerocenter=zerocenter,
distances=distances,
harmonic=(not domain.harmonic))
......
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
# Copyright(C) 2013-2017 Max-Planck-Society
#
# NIFTy is being developed at the Max-Planck-Institut fuer Astrophysik
# and financially supported by the Studienstiftung des deutschen Volkes.
from builtins import object
class _TransformationCache(object):
def __init__(self):
self.cache = {}
def create(self, transformation_class, domain, codomain, module):
key = (domain, codomain, module)
if key not in self.cache:
self.cache[key] = transformation_class(domain, codomain, module)
return self.cache[key]
TransformationCache = _TransformationCache()
......@@ -42,6 +42,4 @@ class RG1DPlotter(PlotterBase):
num=num,
endpoint=False)
if rgspace.zerocenter[0]:
xy_data[0] -= np.floor(length/2)
return xy_data
......@@ -43,10 +43,6 @@ class RGSpace(Space):
----------
shape : {int, numpy.ndarray}
Number of grid points or numbers of gridpoints along each axis.
zerocenter : {bool, numpy.ndarray} *optional*
Whether x==0 (or k==0, respectively) is located in the center of
the grid (or the center of each axis speparately) or not.
(default: False).
distances : {float, numpy.ndarray}, *optional*
Distance between two grid points along each axis
(default: None).
......@@ -63,9 +59,6 @@ class RGSpace(Space):
----------
harmonic : bool
Whether or not the grid represents a position or harmonic space.
zerocenter : tuple of bool
Whether x==0 (or k==0, respectively) is located in the center of
the grid (or the center of each axis speparately) or not.
distances : tuple of floats
Distance between two grid points along the correponding axis.
dim : np.int
......@@ -81,15 +74,13 @@ class RGSpace(Space):
# ---Overwritten properties and methods---
def __init__(self, shape, zerocenter=False, distances=None,
harmonic=False):
def __init__(self, shape, distances=None, harmonic=False):
self._harmonic = bool(harmonic)
super(RGSpace, self).__init__()
self._shape = self._parse_shape(shape)
self._distances = self._parse_distances(distances)
self._zerocenter = self._parse_zerocenter(zerocenter)
# This code is unused but may be useful to keep around if it is ever needed
# again in the future ...
......@@ -130,12 +121,8 @@ class RGSpace(Space):
i = axes[k]
slice_picker = slice_primitive[:]
slice_inverter = slice_primitive[:]
if (not self.zerocenter[k]) or self.shape[k] % 2 == 0:
slice_picker[i] = slice(1, None, None)
slice_inverter[i] = slice(None, 0, -1)
else:
slice_picker[i] = slice(None)
slice_inverter[i] = slice(None, None, -1)
slice_picker[i] = slice(1, None, None)
slice_inverter[i] = slice(None, 0, -1)
slice_picker = tuple(slice_picker)
slice_inverter = tuple(slice_inverter)
......@@ -149,8 +136,8 @@ class RGSpace(Space):
# ---Mandatory properties and methods---
def __repr__(self):
return ("RGSpace(shape=%r, zerocenter=%r, distances=%r, harmonic=%r)"
% (self.shape, self.zerocenter, self.distances, self.harmonic))
return ("RGSpace(shape=%r, distances=%r, harmonic=%r)"
% (self.shape, self.distances, self.harmonic))
@property
def harmonic(self):
......@@ -170,7 +157,6 @@ class RGSpace(Space):
def copy(self):
return self.__class__(shape=self.shape,
zerocenter=self.zerocenter,
distances=self.distances,
harmonic=self.harmonic)
......@@ -234,16 +220,13 @@ class RGSpace(Space):
dists = (cords[0] - shape[0]//2)*dk[0]
dists *= dists
# apply zerocenterQ shift
if not self.zerocenter[0]:
dists = np.fft.ifftshift(dists)
dists = np.fft.ifftshift(dists)
# only save the individual slice
dists = dists[slice_of_first_dimension]
for ii in range(1, len(shape)):
temp = (cords[ii] - shape[ii] // 2) * dk[ii]
temp *= temp
if not self.zerocenter[ii]:
temp = np.fft.ifftshift(temp)
temp = np.fft.ifftshift(temp)
dists = dists + temp
dists = np.sqrt(dists)
return dists
......@@ -298,21 +281,6 @@ class RGSpace(Space):
return self._distances
@property
def zerocenter(self):
"""Returns True if grid points lie symmetrically around zero.
Returns
-------
bool
True if the grid points are centered around the 0 grid point. This
option is most common for harmonic spaces (where both conventions
are used) but may be used for position spaces, too.
"""
return self._zerocenter
def _parse_shape(self, shape):
if np.isscalar(shape):
shape = (shape,)
......@@ -331,18 +299,10 @@ class RGSpace(Space):
temp[:] = distances
return tuple(temp)
def _parse_zerocenter(self, zerocenter):
temp = np.empty(len(self.shape), dtype=bool)
temp[:] = zerocenter
if np.any(np.logical_and(temp, np.array(self.shape) % 2)):
raise ValueError("All zerocentered axis must have even length!")
return tuple(temp)
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group['shape'] = self.shape
hdf5_group['zerocenter'] = self.zerocenter
hdf5_group['distances'] = self.distances
hdf5_group['harmonic'] = self.harmonic
......@@ -352,7 +312,6 @@ class RGSpace(Space):
def _from_hdf5(cls, hdf5_group, repository):
result = cls(
shape=hdf5_group['shape'][:],
zerocenter=hdf5_group['zerocenter'][:],
distances=hdf5_group['distances'][:],
harmonic=hdf5_group['harmonic'][()],
)
......
......@@ -61,16 +61,12 @@ class Test_Interface(unittest.TestCase):
class Test_Functionality(unittest.TestCase):
@expand(product([True, False], [True, False],
[True, False], [True, False],
[(1,), (4,), (5,)], [(1,), (6,), (7,)]))
def test_hermitian_decomposition(self, z1, z2, preserve, complexdata,
s1, s2):
try:
r1 = RGSpace(s1, harmonic=True, zerocenter=(z1,))
r2 = RGSpace(s2, harmonic=True, zerocenter=(z2,))
ra = RGSpace(s1+s2, harmonic=True, zerocenter=(z1, z2))
except ValueError:
raise SkipTest
def test_hermitian_decomposition(self, preserve, complexdata, s1, s2):
np.random.seed(123)
r1 = RGSpace(s1, harmonic=True)
r2 = RGSpace(s2, harmonic=True)
ra = RGSpace(s1+s2, harmonic=True)
if preserve:
complexdata=True
......@@ -91,12 +87,9 @@ class Test_Functionality(unittest.TestCase):
assert_almost_equal(h1.get_full_data(), h3.get_full_data())
assert_almost_equal(a1.get_full_data(), a3.get_full_data())
@expand(product([RGSpace((8,), harmonic=True,
zerocenter=False),
RGSpace((8, 8), harmonic=True, distances=0.123,
zerocenter=True)],
[RGSpace((8,), harmonic=True,
zerocenter=False),
@expand(product([RGSpace((8,), harmonic=True),
RGSpace((8, 8), harmonic=True, distances=0.123)],
[RGSpace((8,), harmonic=True),
LMSpace(12)],
['real', 'complex']))
def test_power_synthesize_analyze(self, space1, space2, base):
......
......@@ -32,9 +32,8 @@ from nifty import dependency_injector as gdi
from nose.plugins.skip import SkipTest