Skip to content
Snippets Groups Projects
Unverified Commit 2fa69f86 authored by Gordian Edenhofer's avatar Gordian Edenhofer Committed by GitHub
Browse files

Merge pull request #39 from NIFTy-PPL/cf_grid_numpy

re.correlated_field: Numpy grid construction
parents 49ca3958 a82e5608
No related branches found
No related tags found
No related merge requests found
Pipeline #238023 passed
......@@ -53,15 +53,15 @@ def get_sht(nside, axis, lmax, mmax, nthreads):
def _unique_mode_distributor(m_length, uniqueness_rtol=1e-12):
# Construct an array of unique mode lengths
um = jnp.unique(m_length)
um = np.unique(m_length)
tol = uniqueness_rtol * um[-1]
um = um[jnp.diff(jnp.append(um, 2 * um[-1])) > tol]
um = um[np.diff(np.append(um, 2 * um[-1])) > tol]
# Group modes based on their length and store the result as power
# distributor
binbounds = 0.5 * (um[:-1] + um[1:])
m_length_idx = jnp.searchsorted(binbounds, m_length)
m_count = jnp.bincount(m_length_idx.ravel(), minlength=um.size)
if jnp.any(m_count == 0) or um.shape != m_count.shape:
m_length_idx = np.searchsorted(binbounds, m_length)
m_count = np.bincount(m_length_idx.ravel(), minlength=um.size)
if np.any(m_count == 0) or um.shape != m_count.shape:
raise RuntimeError("invalid harmonic mode(s) encountered")
return m_length_idx, um, m_count
......@@ -160,17 +160,17 @@ def get_fourier_mode_distributor(
shape = (shape,) if isinstance(shape, int) else tuple(shape)
# Compute length of modes
mspc_distances = 1.0 / (jnp.array(shape) * jnp.array(distances))
m_length = jnp.arange(shape[0], dtype=jnp.float64)
m_length = jnp.minimum(m_length, shape[0] - m_length) * mspc_distances[0]
mspc_distances = 1.0 / (np.array(shape) * np.array(distances))
m_length = np.arange(shape[0])
m_length = np.minimum(m_length, shape[0] - m_length) * mspc_distances[0]
if len(shape) != 1:
m_length *= m_length
for i in range(1, len(shape)):
tmp = jnp.arange(shape[i], dtype=jnp.float64)
tmp = jnp.minimum(tmp, shape[i] - tmp) * mspc_distances[i]
tmp = np.arange(shape[i])
tmp = np.minimum(tmp, shape[i] - tmp) * mspc_distances[i]
tmp *= tmp
m_length = jnp.expand_dims(m_length, axis=-1) + tmp
m_length = jnp.sqrt(m_length)
m_length = np.expand_dims(m_length, axis=-1) + tmp
m_length = np.sqrt(m_length)
return _unique_mode_distributor(m_length, uniqueness_rtol=uniqueness_rtol)
......@@ -226,8 +226,8 @@ LMGrid = namedtuple(
def _log_modes(m_length):
um = m_length.copy()
um = um.at[1:].set(jnp.log(um[1:]))
um = um.at[1:].add(-um[1])
um[1:] = np.log(um[1:])
um[1:] -= um[1]
assert um[0] == 0.0
log_vol = um[2:] - um[1:-1]
assert um.shape[0] - 2 == log_vol.shape[0]
......@@ -242,9 +242,9 @@ def make_grid(
# Pre-compute lengths of modes and indices for distributing power
if harmonic_type.lower() == "fourier":
distances = tuple(np.broadcast_to(distances, jnp.shape(shape)))
distances = tuple(np.broadcast_to(distances, np.shape(shape)))
totvol = jnp.prod(jnp.array(shape) * jnp.array(distances))
totvol = np.prod(np.array(shape) * np.array(distances))
m_length_idx, m_length, m_count = get_fourier_mode_distributor(shape, distances)
um, log_vol = _log_modes(m_length)
harmonic_grid = RegularFourierGrid(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment