Skip to content
Snippets Groups Projects
Commit 1e7b0830 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

more disentangling

parent 181e0e4d
No related branches found
No related tags found
1 merge request!157Fix hermitianizer
Pipeline #
...@@ -102,6 +102,12 @@ class RGSpace(Space): ...@@ -102,6 +102,12 @@ class RGSpace(Space):
def hermitian_decomposition(self, x, axes=None, def hermitian_decomposition(self, x, axes=None,
preserve_gaussian_variance=False): preserve_gaussian_variance=False):
# check axes
if axes is None:
axes = range(len(x.shape))
assert len(x.shape) >= len(self.shape), "shapes mismatch"
assert len(axes) == len(self.shape), "axes mismatch"
# compute the hermitian part # compute the hermitian part
flipped_x = self._hermitianize_inverter(x, axes=axes) flipped_x = self._hermitianize_inverter(x, axes=axes)
flipped_x = flipped_x.conjugate() flipped_x = flipped_x.conjugate()
...@@ -137,10 +143,9 @@ class RGSpace(Space): ...@@ -137,10 +143,9 @@ class RGSpace(Space):
dimensions = mid_index.size dimensions = mid_index.size
# Use ndindex to iterate over all combinations of zeros and the # Use ndindex to iterate over all combinations of zeros and the
# mid_index in order to correct all fixed points. # mid_index in order to correct all fixed points.
if axes is None:
axes = xrange(dimensions)
ndlist = [2 if i in axes else 1 for i in xrange(dimensions)] ndlist = [2 if i in axes and self.shape[i] % 2 == 0
else 1 for i in xrange(dimensions)]
ndlist = tuple(ndlist) ndlist = tuple(ndlist)
for i in np.ndindex(ndlist): for i in np.ndindex(ndlist):
temp_index = tuple(i * mid_index) temp_index = tuple(i * mid_index)
...@@ -149,22 +154,18 @@ class RGSpace(Space): ...@@ -149,22 +154,18 @@ class RGSpace(Space):
return hermitian_part, anti_hermitian_part return hermitian_part, anti_hermitian_part
def _hermitianize_inverter(self, x, axes): def _hermitianize_inverter(self, x, axes):
shape = x.shape
# calculate the number of dimensions the input array has # calculate the number of dimensions the input array has
dimensions = len(shape) dimensions = len(x.shape)
# prepare the slicing object which will be used for mirroring # prepare the slicing object which will be used for mirroring
slice_primitive = [slice(None), ] * dimensions slice_primitive = [slice(None), ] * dimensions
# copy the input data # copy the input data
y = x.copy() y = x.copy()
if axes is None:
axes = xrange(dimensions)
# flip in the desired directions # flip in the desired directions
for i in axes: for i in axes:
slice_picker = slice_primitive[:] slice_picker = slice_primitive[:]
slice_inverter = slice_primitive[:] slice_inverter = slice_primitive[:]
if self.zerocenter[i] == False or shape[i] % 2 == 0: if self.zerocenter[i] is False or self.shape[i] % 2 == 0:
slice_picker[i] = slice(1, None, None) slice_picker[i] = slice(1, None, None)
slice_inverter[i] = slice(None, 0, -1) slice_inverter[i] = slice(None, 0, -1)
else: else:
...@@ -174,7 +175,8 @@ class RGSpace(Space): ...@@ -174,7 +175,8 @@ class RGSpace(Space):
slice_inverter = tuple(slice_inverter) slice_inverter = tuple(slice_inverter)
try: try:
y.set_data(to_key=slice_picker, data=y, from_key=slice_inverter) y.set_data(to_key=slice_picker, data=y,
from_key=slice_inverter)
except(AttributeError): except(AttributeError):
y[slice_picker] = y[slice_inverter] y[slice_picker] = y[slice_inverter]
return y return y
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment