Commit 4606ccbf authored by Jait Dixit's avatar Jait Dixit
Browse files

WIP: Add weight method to GLSpace and HPSpace

parent 8d6df7d5
...@@ -11,9 +11,8 @@ from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES ...@@ -11,9 +11,8 @@ from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.spaces.lm_space import LMSpace from nifty.spaces.lm_space import LMSpace
from nifty.spaces.space import Space from nifty.spaces.space import Space
from nifty.config import about, \ from nifty.config import about, nifty_configuration as gc,\
nifty_configuration as gc, \ dependency_injector as gdi
dependency_injector as gdi
from gl_space_paradict import GLSpaceParadict from gl_space_paradict import GLSpaceParadict
from nifty.nifty_random import random from nifty.nifty_random import random
...@@ -132,6 +131,11 @@ class GLSpace(Space): ...@@ -132,6 +131,11 @@ class GLSpace(Space):
return np.sum(self.paradict['nlon'] * np.array(self.distances[0])) return np.sum(self.paradict['nlon'] * np.array(self.distances[0]))
def weight(self, x, power=1, axes=None, inplace=False): def weight(self, x, power=1, axes=None, inplace=False):
# check if the axes provided are valid given the input shape
if axes is not None and \
not all(axis in range(len(x.shape)) for axis in axes):
raise ValueError("ERROR: Provided axes does not match array shape")
weight = np.array(list( weight = np.array(list(
itertools.chain.from_iterable( itertools.chain.from_iterable(
itertools.repeat(x ** power, self.paradict['nlon']) itertools.repeat(x ** power, self.paradict['nlon'])
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
""" """
from __future__ import division from __future__ import division
import itertools
import numpy as np import numpy as np
import pylab as pl import pylab as pl
...@@ -42,7 +43,7 @@ from nifty.spaces.lm_space import LMSpace ...@@ -42,7 +43,7 @@ from nifty.spaces.lm_space import LMSpace
from nifty.spaces.space import Space from nifty.spaces.space import Space
from nifty.config import about, nifty_configuration as gc,\ from nifty.config import about, nifty_configuration as gc, \
dependency_injector as gdi dependency_injector as gdi
from hp_space_paradict import HPSpaceParadict from hp_space_paradict import HPSpaceParadict
from nifty.nifty_random import random from nifty.nifty_random import random
...@@ -148,7 +149,34 @@ class HPSpace(Space): ...@@ -148,7 +149,34 @@ class HPSpace(Space):
return np.int(12 * self.paradict['nside'] ** 2) return np.int(12 * self.paradict['nside'] ** 2)
def weight(self, x, power=1, axes=None, inplace=False): def weight(self, x, power=1, axes=None, inplace=False):
pass # check if the axes provided are valid given the input shape
if axes is not None and \
not all(axis in range(len(x.shape)) for axis in axes):
raise ValueError("ERROR: Provided axes does not match array shape")
weight = np.array(list(
itertools.chain.from_iterable(
itertools.repeat(
(4 * np.pi / 12 * self.paradict['nside'] ** 2) ** power,
12 * self.paradict['nside'] ** 2
)
)
))
if axes is not None:
# reshape the weight array to match the input shape
new_shape = np.ones(x.shape)
for index in range(len(axes)):
new_shape[index] = len(weight)
weight = weight.reshape(new_shape)
if inplace:
x *= weight
result_x = x
else:
result_x = x * weight
return result_x
def get_plot(self, x, title="", vmin=None, vmax=None, power=False, unit="", def get_plot(self, x, title="", vmin=None, vmax=None, power=False, unit="",
norm=None, cmap=None, cbar=True, other=None, legend=False, norm=None, cmap=None, cbar=True, other=None, legend=False,
...@@ -256,7 +284,7 @@ class HPSpace(Space): ...@@ -256,7 +284,7 @@ class HPSpace(Space):
color=[max(0.0, 1.0 - (2 * ii / imax) ** 2), color=[max(0.0, 1.0 - (2 * ii / imax) ** 2),
0.5 * ((2 * ii - imax) / imax) 0.5 * ((2 * ii - imax) / imax)
** 2, max(0.0, 1.0 - ( ** 2, max(0.0, 1.0 - (
2 * (ii - imax) / imax) ** 2)], 2 * (ii - imax) / imax) ** 2)],
label="graph " + str(ii + 1), linestyle='-', label="graph " + str(ii + 1), linestyle='-',
linewidth=1.0, zorder=-ii) linewidth=1.0, zorder=-ii)
if (mono): if (mono):
...@@ -266,7 +294,8 @@ class HPSpace(Space): ...@@ -266,7 +294,8 @@ class HPSpace(Space):
0.5 * ((2 * ii - imax) / imax) ** 2, 0.5 * ((2 * ii - imax) / imax) ** 2,
max( max(
0.0, 1.0 - ( 0.0, 1.0 - (
2 * (ii - imax) / imax) ** 2)], 2 * (
ii - imax) / imax) ** 2)],
marker='o', cmap=None, norm=None, vmin=None, marker='o', cmap=None, norm=None, vmin=None,
vmax=None, alpha=None, linewidths=None, vmax=None, alpha=None, linewidths=None,
verts=None, zorder=-ii) verts=None, zorder=-ii)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment