Commit ebbcb0bb authored by Jakob Knollmueller's avatar Jakob Knollmueller
Browse files

plotting multi-frequency maps for RG, HP and GL space added

parent 071fc76e
......@@ -28,6 +28,7 @@ from .domains.gl_space import GLSpace
from .domains.hp_space import HPSpace
from .domains.power_space import PowerSpace
from .domains.rg_space import RGSpace
from .domain_tuple import DomainTuple
from .field import Field
# relevant properties:
......@@ -58,6 +59,25 @@ def _mollweide_helper(xsize):
return res, mask, theta, phi
def _rgb_data(spectral_cube):
def _eye_sensitivity(energy_bins):
a = np.arange(0, 1, 1 / energy_bins)
rgb = np.empty((3, energy_bins))
rgb[0] = np.exp(-(a - 5 / 12) ** 2 / (2 * (2.5 / 12) ** 2))
rgb[1] = np.exp(-(a - 6.5 / 12) ** 2 / (2 * (2 / 12) ** 2))
rgb[2] = np.exp(-(a - 10 / 12) ** 2 / (2 * (1 / 12) ** 2))
rgb[0] /= rgb[0].max()
rgb[1] /= rgb[1].max()
rgb[2] /= rgb[2].max()
return rgb
rgb = _eye_sensitivity(spectral_cube.shape[-1])
rgb_data = np.tensordot(spectral_cube, rgb, axes=[-1, -1])
rgb_data = np.log(rgb_data)
rgb_data -= rgb_data.min()
rgb_data /= rgb_data.max()
return rgb_data
def _find_closest(A, target):
# A must be sorted
idx = np.clip(A.searchsorted(target), 1, len(A)-1)
......@@ -168,8 +188,8 @@ def _plot(f, ax, **kwargs):
raise TypeError("incorrect data type")
if i == 0:
dom = fld.domain
if len(dom) != 1:
raise ValueError("input field must have exactly one domain")
if (len(dom) > 2) or (len(dom) < 1):
raise ValueError("input field must have either one domain or additionally an energy direction")
else:
if fld.domain != dom:
raise ValueError("domain mismatch")
......@@ -192,11 +212,49 @@ def _plot(f, ax, **kwargs):
foo = kwargs.pop("norm", None)
norm = {} if foo is None else {'norm': foo}
dom = dom[0]
ax.set_title(kwargs.pop("title", ""))
ax.set_xlabel(kwargs.pop("xlabel", ""))
ax.set_ylabel(kwargs.pop("ylabel", ""))
cmap = kwargs.pop("colormap", plt.rcParams['image.cmap'])
if isinstance(dom, DomainTuple):
if isinstance(dom[1], RGSpace):
if isinstance(dom[0],RGSpace):
if len(dom[0].shape) == 2:
nx, ny = dom[0].shape
dx, dy = dom[0].distances
rgb = _rgb_data(f[0].to_global_data())
im = ax.imshow(
rgb, extent=[0, nx * dx, 0, ny * dy], origin="lower", **norm)
# from mpl_toolkits.axes_grid1 import make_axes_locatable
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(im,cax=cax)
_limit_xy(**kwargs)
return
if isinstance(dom[0],(HPSpace, GLSpace)):
import pyHealpix
xsize = 800
res, mask, theta, phi = _mollweide_helper(xsize)
res = np.full(shape=res.shape+(3,), fill_value=1., dtype=np.float64)
rgb = _rgb_data(f[0].to_global_data())
if isinstance(dom[0], HPSpace):
ptg = np.empty((phi.size, 2), dtype=np.float64)
ptg[:, 0] = theta
ptg[:, 1] = phi
base = pyHealpix.Healpix_Base(int(np.sqrt(dom[0].size // 12)), "RING")
res[mask] = rgb[base.ang2pix(ptg)]
else:
ra = np.linspace(0, 2 * np.pi, dom[0].nlon + 1)
dec = pyHealpix.GL_thetas(dom[0].nlat)
ilat = _find_closest(dec, theta)
ilon = _find_closest(ra, phi)
ilon = np.where(ilon == dom[0].nlon, 0, ilon)
res[mask] = rgb[ilat * dom[0].nlon + ilon]
plt.axis('off')
plt.imshow(res, origin="lower")
return
else:
dom = dom[0]
if isinstance(dom, RGSpace):
if len(dom.shape) == 1:
npoints = dom.shape[0]
......@@ -258,7 +316,6 @@ def _plot(f, ax, **kwargs):
cmap=cmap, origin="lower")
plt.colorbar(orientation="horizontal")
return
raise ValueError("Field type not(yet) supported")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment