# Simulation parameters: replace the `...` placeholders before use.
# NOTE: while the placeholders are left in, `BoxSize / PMGRID` below raises
# TypeError (Ellipsis does not support division) as soon as the file runs.
BoxSize = ...  # side length of the periodic box, shared by all three axes
PMGRID = ...  # number of grid cells along each axis
cell_len = BoxSize / PMGRID  # side length of one grid cell
stride = [PMGRID**2, PMGRID, 1]  # flat-index strides: x slowest, z fastest
ngrid = [PMGRID] * 3  # number of cells per axis, in the same order as `stride`


def FDM_FI(x, y, z):
	"""Flatten the cell coordinates ``(x, y, z)`` into a single flat index.

	The layout is row-major: ``z`` varies fastest, ``x`` slowest. Every
	coordinate must lie in ``[0, PMGRID)``; this is checked with ``assert``.
	"""
	assert all(0 <= coord < PMGRID for coord in (x, y, z))
	return (x * PMGRID + y) * PMGRID + z

def fdm_idx_1to3(idx, strides=None, sizes=None):
	"""Convert a flat cell index into per-axis integer cell coordinates.

	With the default strides this inverts the row-major flattening done by
	``FDM_FI``: the first axis varies slowest, the last one fastest.

	Args:
		idx: flat cell index.
		strides: per-axis strides, largest first; defaults to the
			module-level ``stride``.
		sizes: number of cells along each axis; defaults to the
			module-level ``ngrid``.

	Returns:
		A list with one integer coordinate per axis.

	Raises:
		AssertionError: if ``idx`` lies outside the grid, including
			negative indices.
	"""
	if strides is None:
		strides = stride
	if sizes is None:
		sizes = ngrid
	coords = []
	for step, size in zip(strides, sizes):
		coord = idx // step
		# Check the lower bound too: a negative idx would otherwise silently
		# produce negative coordinates, which FDM_FI itself rejects.
		assert 0 <= coord < size
		coords.append(coord)
		idx -= coord * step
	return coords

def fdm_idx_to_pos(cell_len, idx):
	"""Return the position of the center of the cell with flat index ``idx``.

	Along each of the three axes the result is ``(i + 0.5) * cell_len``,
	where ``i`` is the cell's integer coordinate on that axis.
	"""
	cell = fdm_idx_1to3(idx)
	return [(cell[axis] + 0.5) * cell_len for axis in range(3)]

def subdivide_evenly(N, pieces, index):
	"""Split ``N`` items into ``pieces`` contiguous runs and describe one run.

	The first ``N % pieces`` runs hold one item more than the others. Run
	sizes therefore differ by at most one, and the runs tile ``[0, N)`` in
	order.

	Args:
		N: total number of items (non-negative).
		pieces: number of runs (at least 1).
		index: which run to describe, in ``[0, pieces)``.

	Returns:
		``(first, count)``: the first item of run ``index`` and its length.

	Raises:
		ValueError: if any argument is out of range.
	"""
	if pieces < 1:
		raise ValueError(f"pieces must be >= 1, got {pieces}")
	if N < 0:
		raise ValueError(f"N must be >= 0, got {N}")
	if not 0 <= index < pieces:
		raise ValueError(f"index must be in [0, {pieces}), got {index}")
	base, extra = divmod(N, pieces)
	if index < extra:
		# Every run before `extra` holds base + 1 items.
		return index * (base + 1), base + 1
	# Skip the `extra` larger runs, then `index - extra` runs of size base.
	return index * base + extra, base

def slab_to_task(NTask):
	"""Map every x-slab of the grid to the task that owns it.

	The ``PMGRID`` slabs are split into ``NTask`` contiguous runs with
	``subdivide_evenly``. The returned list has ``PMGRID`` entries; entry
	``ix`` is the index of the task owning slab ``ix``.
	"""
	owners = []
	for task in range(NTask):
		# Runs come back in task order and are contiguous, so only the slab
		# count is needed; appending keeps the list in x-slab order.
		_, nslab_x = subdivide_evenly(PMGRID, NTask, task)
		owners.extend([task] * nslab_x)
	assert len(owners) == PMGRID
	return owners

def fdm_idx_to_task(idx, NTask):
	"""Return the task owning the x-slab that contains flat cell index ``idx``."""
	slab_x = fdm_idx_1to3(idx)[0]
	owners = slab_to_task(NTask)
	return owners[slab_x]

def dist_periodic_wrap_array(box_size, x):
	"""Wrap separations into the periodic range ``[-box_size/2, box_size/2)``.

	Each entry of the numpy array ``x`` is shifted by whole multiples of
	``box_size`` until it lies in range. ``x`` is modified in place and also
	returned.

	The work is vectorized: every entry that is still out of range takes the
	same one-period steps that a per-element loop would take. For float64
	arrays the results are therefore identical to stepping element by
	element. The number of passes equals the largest number of periods any
	entry is away from the range, so entries very far outside the box are
	slow to wrap.

	Args:
		box_size: period of the box; must be positive.
		x: numpy array of separations. Any shape is accepted; callers in
			this file pass one row per particle.

	Returns:
		``x`` itself, after wrapping.

	Raises:
		ValueError: if ``box_size`` is not positive or ``x`` contains an
			infinity; either would make the stepping never terminate.
	"""
	import numpy

	if not box_size > 0:
		raise ValueError(f"box_size must be positive, got {box_size!r}")
	if numpy.isinf(x).any():
		raise ValueError("cannot wrap an infinite separation")
	half = box_size / 2
	while True:
		low = x < -half
		if not low.any():
			break
		# Plain assignment rather than `+=`: it casts back into integer arrays
		# (like element-wise assignment) instead of raising a casting error.
		x[low] = x[low] + box_size
	while True:
		high = x >= half
		if not high.any():
			break
		x[high] = x[high] - box_size
	return x

def coord_periodic_wrap_vec(box_size, x):
	"""Wrap every coordinate of ``x`` into the periodic range ``[0, box_size)``.

	``x`` is modified in place and also returned. It can be any mutable
	sequence indexable by position, e.g. a list or a 1-D numpy array. Each
	entry is shifted by one period at a time, so entries many periods
	outside the box take proportionally many steps.

	Raises:
		ValueError: if ``box_size`` is not positive or ``x`` contains an
			infinity; either would make the stepping loops below never
			terminate. ``x`` is left untouched in that case.
	"""
	import math

	if not box_size > 0:
		raise ValueError(f"box_size must be positive, got {box_size!r}")
	# Validate everything before mutating, so a failure leaves x unchanged.
	if any(math.isinf(value) for value in x):
		raise ValueError("cannot wrap an infinite coordinate")
	for i in range(len(x)):
		while x[i] < 0:
			x[i] += box_size
		while x[i] >= box_size:
			x[i] -= box_size
	return x

def center_of_mass_particles(pos, box_size, firstpos):
	"""Return the periodic center of mass of a set of equal-mass particles.

	Steps:
	1. Take every position relative to ``firstpos``.
	2. Wrap those offsets into ``[-box_size/2, box_size/2)`` (minimum image),
	   so particles straddling the periodic boundary are averaged as one
	   group rather than split across opposite faces of the box.
	3. Add the mean offset back to ``firstpos`` and wrap the result into
	   ``[0, box_size)``.

	NOTE(review): this assumes the particles lie within half a box of
	``firstpos`` along each axis; confirm for extended groups.

	Args:
		pos: positions, one row per particle (array-like).
		box_size: side length of the periodic box.
		firstpos: reference position, typically ``pos[0]``.

	Returns:
		The center of mass as a numpy array, or None if ``pos`` is empty.
	"""
	# Local import: the module has no top-level `import numpy`, so the bare
	# `numpy` name used here would otherwise raise NameError.
	import numpy

	pos = numpy.asarray(pos)
	num_particles = len(pos)
	if num_particles == 0:
		return None
	offsets = pos - numpy.asarray(firstpos)
	if offsets.dtype.kind in "iu":
		# Integer offsets would be truncated when shifted by a fractional
		# box_size during wrapping, so average in floating point instead.
		offsets = offsets.astype(float)
	offsets = dist_periodic_wrap_array(box_size, offsets)
	# All particles have the same mass, so the mass cancels out of the mean.
	mean_offset = numpy.sum(offsets, axis=0) / num_particles
	return coord_periodic_wrap_vec(box_size, mean_offset + firstpos)

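## Example: center of mass of a set of grid cells (needs `import numpy`).
## `idx[i][0]` is the flat cell index of the i-th cell.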
# idx = ...
# ipos = numpy.full((len(idx), 3), -1)
# pos = numpy.full((len(idx), 3), -1.)
# for i in range(len(pos)):
# 	ipos[i, :] = fdm_idx_1to3(idx[i][0])
# 	pos[i, :] = fdm_idx_to_pos(cell_len, idx[i][0])
# com = center_of_mass_particles(pos, BoxSize, pos[0])