Commit 3e85f6ff authored by Martin Reinecke's avatar Martin Reinecke

simplify get_distance_array

parent 07fc9bc8
Pipeline #17947 passed with stage
in 4 minutes and 54 seconds
......@@ -31,9 +31,7 @@
from __future__ import division
from builtins import range
from functools import reduce
import numpy as np
from ..space import Space
......@@ -83,10 +81,9 @@ class RGSpace(Space):
# ---Overwritten properties and methods---
def __init__(self, shape, distances=None, harmonic=False):
self._harmonic = bool(harmonic)
super(RGSpace, self).__init__()
self._harmonic = bool(harmonic)
self._shape = self._parse_shape(shape)
self._distances = self._parse_distances(distances)
......@@ -134,35 +131,17 @@ class RGSpace(Space):
if (not self.harmonic):
raise NotImplementedError
shape = self.shape
slice_of_first_dimension = slice(0, shape[0])
dists = self._distance_array_helper(slice_of_first_dimension)
return dists
def _distance_array_helper(self, slice_of_first_dimension):
dk = self.distances
shape = self.shape
inds = []
for a in shape:
inds += [slice(0, a)]
cords = np.ogrid[inds]
dists = (cords[0] - shape[0]//2)*dk[0]
dists *= dists
dists = np.fft.ifftshift(dists)
# only save the individual slice
dists = dists[slice_of_first_dimension]
for ii in range(1, len(shape)):
temp = (cords[ii] - shape[ii] // 2) * dk[ii]
temp *= temp
temp = np.fft.ifftshift(temp)
dists = dists + temp
dists = np.sqrt(dists)
return dists
res = np.arange(self.shape[0], dtype=np.float64)
res = np.minimum(res, self.shape[0]-res)*self.distances[0]
if len(self.shape) == 1:
return res
res *= res
for i in range(1, len(self.shape)):
tmp = np.arange(self.shape[i], dtype=np.float64)
tmp = np.minimum(tmp, self.shape[i]-tmp)*self.distances[i]
tmp *= tmp
res = np.add.outer(res, tmp)
return np.sqrt(res)
def get_unique_distances(self):
if (not self.harmonic):
......@@ -184,7 +163,7 @@ class RGSpace(Space):
tmp[t2] = True
return np.sqrt(np.nonzero(tmp)[0])*self.distances[0]
else: # do it the hard way
tmp = self.get_distance_array('not').unique() # expensive!
tmp = self.get_distance_array().unique() # expensive!
tol = 1e-12*tmp[-1]
# remove all points that are closer than tol to their right
# neighbors.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment