Commit 4606ccbf authored by Jait Dixit's avatar Jait Dixit

WIP: Add weight method to GLSpace and HPSpace

parent 8d6df7d5
......@@ -11,9 +11,8 @@ from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.spaces.lm_space import LMSpace
from nifty.spaces.space import Space
from nifty.config import about, \
nifty_configuration as gc, \
dependency_injector as gdi
from nifty.config import about, nifty_configuration as gc,\
dependency_injector as gdi
from gl_space_paradict import GLSpaceParadict
from nifty.nifty_random import random
......@@ -132,6 +131,11 @@ class GLSpace(Space):
return np.sum(self.paradict['nlon'] * np.array(self.distances[0]))
def weight(self, x, power=1, axes=None, inplace=False):
# check if the axes provided are valid given the input shape
if axes is not None and \
not all(axis in range(len(x.shape)) for axis in axes):
raise ValueError("ERROR: Provided axes does not match array shape")
weight = np.array(list(
itertools.chain.from_iterable(
itertools.repeat(x ** power, self.paradict['nlon'])
......
......@@ -33,6 +33,7 @@
"""
from __future__ import division
import itertools
import numpy as np
import pylab as pl
......@@ -42,7 +43,7 @@ from nifty.spaces.lm_space import LMSpace
from nifty.spaces.space import Space
from nifty.config import about, nifty_configuration as gc,\
from nifty.config import about, nifty_configuration as gc, \
dependency_injector as gdi
from hp_space_paradict import HPSpaceParadict
from nifty.nifty_random import random
......@@ -148,7 +149,34 @@ class HPSpace(Space):
return np.int(12 * self.paradict['nside'] ** 2)
def weight(self, x, power=1, axes=None, inplace=False):
pass
# check if the axes provided are valid given the input shape
if axes is not None and \
not all(axis in range(len(x.shape)) for axis in axes):
raise ValueError("ERROR: Provided axes does not match array shape")
weight = np.array(list(
itertools.chain.from_iterable(
itertools.repeat(
(4 * np.pi / 12 * self.paradict['nside'] ** 2) ** power,
12 * self.paradict['nside'] ** 2
)
)
))
if axes is not None:
# reshape the weight array to match the input shape
new_shape = np.ones(x.shape)
for index in range(len(axes)):
new_shape[index] = len(weight)
weight = weight.reshape(new_shape)
if inplace:
x *= weight
result_x = x
else:
result_x = x * weight
return result_x
def get_plot(self, x, title="", vmin=None, vmax=None, power=False, unit="",
norm=None, cmap=None, cbar=True, other=None, legend=False,
......@@ -256,7 +284,7 @@ class HPSpace(Space):
color=[max(0.0, 1.0 - (2 * ii / imax) ** 2),
0.5 * ((2 * ii - imax) / imax)
** 2, max(0.0, 1.0 - (
2 * (ii - imax) / imax) ** 2)],
2 * (ii - imax) / imax) ** 2)],
label="graph " + str(ii + 1), linestyle='-',
linewidth=1.0, zorder=-ii)
if (mono):
......@@ -266,7 +294,8 @@ class HPSpace(Space):
0.5 * ((2 * ii - imax) / imax) ** 2,
max(
0.0, 1.0 - (
2 * (ii - imax) / imax) ** 2)],
2 * (
ii - imax) / imax) ** 2)],
marker='o', cmap=None, norm=None, vmin=None,
vmax=None, alpha=None, linewidths=None,
verts=None, zorder=-ii)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment