Commit b8bd4934 authored by Theo Steininger's avatar Theo Steininger

Merge branch 'byebye_fixed_point_voodoo' into 'master'

Byebye fixed point voodoo

See merge request !173
parents db23017b cbdbca99
Pipeline #15366 passed with stages
in 14 minutes and 6 seconds
...@@ -612,39 +612,47 @@ class Field(Loggable, Versionable, object): ...@@ -612,39 +612,47 @@ class Field(Loggable, Versionable, object):
# correct variance # correct variance
if preserve_gaussian_variance: if preserve_gaussian_variance:
assert issubclass(val.dtype.type, np.complexfloating),\
"complex input field is needed here"
h *= np.sqrt(2) h *= np.sqrt(2)
a *= np.sqrt(2) a *= np.sqrt(2)
if not issubclass(val.dtype.type, np.complexfloating): # The code below should not be needed in practice, since it would
# in principle one must not correct the variance for the fixed # only ever be called when hermitianizing a purely real field.
# points of the hermitianization. However, for a complex field # However it might be of educational use and keep us from forgetting
# the input field loses half of its power at its fixed points # how these things are done ...
# in the `hermitian` part. Hence, here a factor of sqrt(2) is
# also necessary! # if not issubclass(val.dtype.type, np.complexfloating):
# => The hermitianization can be done on a space level since # # in principle one must not correct the variance for the fixed
# either nothing must be done (LMSpace) or ALL points need a # # points of the hermitianization. However, for a complex field
# factor of sqrt(2) # # the input field loses half of its power at its fixed points
# => use the preserve_gaussian_variance flag in the # # in the `hermitian` part. Hence, here a factor of sqrt(2) is
# hermitian_decomposition method above. # # also necessary!
# # => The hermitianization can be done on a space level since
# This code is for educational purposes: # # either nothing must be done (LMSpace) or ALL points need a
fixed_points = [domain[i].hermitian_fixed_points() # # factor of sqrt(2)
for i in spaces] # # => use the preserve_gaussian_variance flag in the
fixed_points = [[fp] if fp is None else fp # # hermitian_decomposition method above.
for fp in fixed_points] #
# # This code is for educational purposes:
for product_point in itertools.product(*fixed_points): # fixed_points = [domain[i].hermitian_fixed_points()
slice_object = np.array((slice(None), )*len(val.shape), # for i in spaces]
dtype=np.object) # fixed_points = [[fp] if fp is None else fp
for i, sp in enumerate(spaces): # for fp in fixed_points]
point_component = product_point[i] #
if point_component is None: # for product_point in itertools.product(*fixed_points):
point_component = slice(None) # slice_object = np.array((slice(None), )*len(val.shape),
slice_object[list(domain_axes[sp])] = point_component # dtype=np.object)
# for i, sp in enumerate(spaces):
slice_object = tuple(slice_object) # point_component = product_point[i]
h[slice_object] /= np.sqrt(2) # if point_component is None:
a[slice_object] /= np.sqrt(2) # point_component = slice(None)
# slice_object[list(domain_axes[sp])] = point_component
#
# slice_object = tuple(slice_object)
# h[slice_object] /= np.sqrt(2)
# a[slice_object] /= np.sqrt(2)
return (h, a) return (h, a)
def _spec_to_rescaler(self, spec, result_list, power_space_index): def _spec_to_rescaler(self, spec, result_list, power_space_index):
......
...@@ -100,23 +100,26 @@ class RGSpace(Space): ...@@ -100,23 +100,26 @@ class RGSpace(Space):
self._distances = self._parse_distances(distances) self._distances = self._parse_distances(distances)
self._zerocenter = self._parse_zerocenter(zerocenter) self._zerocenter = self._parse_zerocenter(zerocenter)
def hermitian_fixed_points(self): # This code is unused but may be useful to keep around if it is ever needed
dimensions = len(self.shape) # again in the future ...
mid_index = np.array(self.shape)//2
ndlist = [1]*dimensions # def hermitian_fixed_points(self):
for k in range(dimensions): # dimensions = len(self.shape)
if self.shape[k] % 2 == 0: # mid_index = np.array(self.shape)//2
ndlist[k] = 2 # ndlist = [1]*dimensions
ndlist = tuple(ndlist) # for k in range(dimensions):
fixed_points = [] # if self.shape[k] % 2 == 0:
for index in np.ndindex(ndlist): # ndlist[k] = 2
for k in range(dimensions): # ndlist = tuple(ndlist)
if self.shape[k] % 2 != 0 and self.zerocenter[k]: # fixed_points = []
index = list(index) # for index in np.ndindex(ndlist):
index[k] = 1 # for k in range(dimensions):
index = tuple(index) # if self.shape[k] % 2 != 0 and self.zerocenter[k]:
fixed_points += [tuple(index * mid_index)] # index = list(index)
return fixed_points # index[k] = 1
# index = tuple(index)
# fixed_points += [tuple(index * mid_index)]
# return fixed_points
def hermitianize_inverter(self, x, axes): def hermitianize_inverter(self, x, axes):
# calculate the number of dimensions the input array has # calculate the number of dimensions the input array has
......
...@@ -161,19 +161,6 @@ class Space(DomainObject): ...@@ -161,19 +161,6 @@ class Space(DomainObject):
raise NotImplementedError( raise NotImplementedError(
"There is no generic co-smoothing kernel for Space base class.") "There is no generic co-smoothing kernel for Space base class.")
def hermitian_fixed_points(self):
""" Returns the array points which remain invariant under the action
of `hermitianize_inverter`
Returns
-------
list of index-tuples
The list contains the index-coordinates of the invariant points.
"""
return None
def hermitianize_inverter(self, x, axes): def hermitianize_inverter(self, x, axes):
""" Inverts/flips x in the context of Hermitian decomposition. """ Inverts/flips x in the context of Hermitian decomposition.
......
...@@ -67,6 +67,8 @@ class Test_Functionality(unittest.TestCase): ...@@ -67,6 +67,8 @@ class Test_Functionality(unittest.TestCase):
r2 = RGSpace(s2, harmonic=True, zerocenter=(z2,)) r2 = RGSpace(s2, harmonic=True, zerocenter=(z2,))
ra = RGSpace(s1+s2, harmonic=True, zerocenter=(z1, z2)) ra = RGSpace(s1+s2, harmonic=True, zerocenter=(z1, z2))
if preserve:
complexdata=True
v = np.random.random(s1+s2) v = np.random.random(s1+s2)
if complexdata: if complexdata:
v = v + 1j*np.random.random(s1+s2) v = v + 1j*np.random.random(s1+s2)
......
...@@ -127,7 +127,3 @@ class LMSpaceFunctionalityTests(unittest.TestCase): ...@@ -127,7 +127,3 @@ class LMSpaceFunctionalityTests(unittest.TestCase):
def test_distance_array(self, lmax, expected): def test_distance_array(self, lmax, expected):
l = LMSpace(lmax) l = LMSpace(lmax)
assert_almost_equal(l.get_distance_array('not').data, expected) assert_almost_equal(l.get_distance_array('not').data, expected)
def test_hermitian_fixed_points(self):
x = LMSpace(5)
assert_equal(x.hermitian_fixed_points(), None)
...@@ -190,8 +190,3 @@ class RGSpaceFunctionalityTests(unittest.TestCase): ...@@ -190,8 +190,3 @@ class RGSpaceFunctionalityTests(unittest.TestCase):
assert_almost_equal(res, expected) assert_almost_equal(res, expected)
if inplace: if inplace:
assert_(x is res) assert_(x is res)
def test_hermitian_fixed_points(self):
x = RGSpace((5, 6, 5, 6), zerocenter=[False, False, True, True])
assert_equal(x.hermitian_fixed_points(),
[(0, 0, 2, 0), (0, 0, 2, 3), (0, 3, 2, 0), (0, 3, 2, 3)])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment