Commit f5c804b2 authored by Martin Reinecke's avatar Martin Reinecke

try to disentangle the logic a bit

parent 0cb13293
......@@ -428,26 +428,20 @@ def _register_cmaps():
plt.register_cmap(cmap=LinearSegmentedColormap("Plus Minus", pm_cmap))
def _plot(f, ax, **kwargs):
def _plot1D(f, ax, **kwargs):
import matplotlib.pyplot as plt
_register_cmaps()
if isinstance(f, Field):
f = [f]
if not isinstance(f, list):
raise TypeError("incorrect data type")
for i, fld in enumerate(f):
if not isinstance(fld, Field):
raise TypeError("incorrect data type")
if i == 0:
dom = fld.domain
if (len(dom) > 2) or (len(dom) < 1):
raise ValueError("input field must have either one domain or additionally an energy direction")
if (len(dom) != 1):
raise ValueError("input field must have exactly one domain")
else:
if fld.domain != dom:
raise ValueError("domain mismatch")
if not (isinstance(dom[0], PowerSpace) or
(isinstance(dom[0], RGSpace) and len(dom[0].shape) == 1)):
raise ValueError("PowerSpace or 1D RGSpace required")
dom = dom[0]
label = kwargs.pop("label", None)
if not isinstance(label, list):
......@@ -461,79 +455,22 @@ def _plot(f, ax, **kwargs):
if not isinstance(alpha, list):
alpha = [alpha] * len(f)
foo = kwargs.pop("norm", None)
norm = {} if foo is None else {'norm': foo}
ax.set_title(kwargs.pop("title", ""))
ax.set_xlabel(kwargs.pop("xlabel", ""))
ax.set_ylabel(kwargs.pop("ylabel", ""))
cmap = kwargs.pop("colormap", plt.rcParams['image.cmap'])
if isinstance(dom, DomainTuple):
if len(dom) == 2:
if isinstance(dom[1], RGSpace):
if isinstance(dom[0],RGSpace):
if len(dom[0].shape) == 2:
nx, ny = dom[0].shape
dx, dy = dom[0].distances
rgb = _rgb_data(f[0].to_global_data())
im = ax.imshow(
rgb, extent=[0, nx * dx, 0, ny * dy], origin="lower", **norm)
# from mpl_toolkits.axes_grid1 import make_axes_locatable
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(im,cax=cax)
_limit_xy(**kwargs)
return
if isinstance(dom[0],(HPSpace, GLSpace)):
import pyHealpix
xsize = 800
res, mask, theta, phi = _mollweide_helper(xsize)
res = np.full(shape=res.shape+(3,), fill_value=1., dtype=np.float64)
rgb = _rgb_data(f[0].to_global_data())
if isinstance(dom[0], HPSpace):
ptg = np.empty((phi.size, 2), dtype=np.float64)
ptg[:, 0] = theta
ptg[:, 1] = phi
base = pyHealpix.Healpix_Base(int(np.sqrt(dom[0].size // 12)), "RING")
res[mask] = rgb[base.ang2pix(ptg)]
else:
ra = np.linspace(0, 2 * np.pi, dom[0].nlon + 1)
dec = pyHealpix.GL_thetas(dom[0].nlat)
ilat = _find_closest(dec, theta)
ilon = _find_closest(ra, phi)
ilon = np.where(ilon == dom[0].nlon, 0, ilon)
res[mask] = rgb[ilat * dom[0].nlon + ilon]
plt.axis('off')
plt.imshow(res, origin="lower")
return
dom = dom[0]
if isinstance(dom, RGSpace):
if len(dom.shape) == 1:
npoints = dom.shape[0]
dist = dom.distances[0]
xcoord = np.arange(npoints, dtype=np.float64)*dist
for i, fld in enumerate(f):
ycoord = fld.to_global_data()
plt.plot(xcoord, ycoord, label=label[i],
linewidth=linewidth[i], alpha=alpha[i])
_limit_xy(**kwargs)
if label != ([None]*len(f)):
plt.legend()
return
elif len(dom.shape) == 2:
nx, ny = dom.shape
dx, dy = dom.distances
im = ax.imshow(
f[0].to_global_data().T, extent=[0, nx*dx, 0, ny*dy],
vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower", **norm)
# from mpl_toolkits.axes_grid1 import make_axes_locatable
# divider = make_axes_locatable(ax)
# cax = divider.append_axes("right", size="5%", pad=0.05)
# plt.colorbar(im,cax=cax)
plt.colorbar(im)
_limit_xy(**kwargs)
return
npoints = dom.shape[0]
dist = dom.distances[0]
xcoord = np.arange(npoints, dtype=np.float64)*dist
for i, fld in enumerate(f):
ycoord = fld.to_global_data()
plt.plot(xcoord, ycoord, label=label[i],
linewidth=linewidth[i], alpha=alpha[i])
_limit_xy(**kwargs)
if label != ([None]*len(f)):
plt.legend()
return
elif isinstance(dom, PowerSpace):
plt.xscale('log')
plt.yscale('log')
......@@ -546,27 +483,107 @@ def _plot(f, ax, **kwargs):
if label != ([None]*len(f)):
plt.legend()
return
raise ValueError("Field type not(yet) supported")
def _plot2D(f, ax, **kwargs):
import matplotlib.pyplot as plt
dom = f.domain
if len(dom) > 2:
raise ValueError("DomainTuple can have at most two entries.")
# check for multifrequency plotting
have_rgb = False
if len(dom) == 2:
if (not isinstance(dom[1], RGSpace)) or len(dom[1].shape) != 1:
raise TypeError("need 1D RGSpace as second domain")
rgb = _rgb_data(f.to_global_data())
have_rgb = True
label = kwargs.pop("label", None)
foo = kwargs.pop("norm", None)
norm = {} if foo is None else {'norm': foo}
ax.set_title(kwargs.pop("title", ""))
ax.set_xlabel(kwargs.pop("xlabel", ""))
ax.set_ylabel(kwargs.pop("ylabel", ""))
dom = dom[0]
if not have_rgb:
cmap = kwargs.pop("colormap", plt.rcParams['image.cmap'])
if isinstance(dom, RGSpace):
nx, ny = dom.shape
dx, dy = dom.distances
if have_rgb:
im = ax.imshow(
rgb, extent=[0, nx*dx, 0, ny*dy], origin="lower", **norm)
else:
im = ax.imshow(
f.to_global_data().T, extent=[0, nx*dx, 0, ny*dy],
vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower", **norm)
plt.colorbar(im)
_limit_xy(**kwargs)
return
elif isinstance(dom, (HPSpace, GLSpace)):
import pyHealpix
xsize = 800
res, mask, theta, phi = _mollweide_helper(xsize)
if have_rgb:
res = np.full(shape=res.shape+(3,), fill_value=1., dtype=np.float64)
if isinstance(dom, HPSpace):
ptg = np.empty((phi.size, 2), dtype=np.float64)
ptg[:, 0] = theta
ptg[:, 1] = phi
base = pyHealpix.Healpix_Base(int(np.sqrt(f[0].size//12)), "RING")
res[mask] = f[0].to_global_data()[base.ang2pix(ptg)]
base = pyHealpix.Healpix_Base(int(np.sqrt(dom.size//12)), "RING")
if have_rgb:
res[mask] = rgb[base.ang2pix(ptg)]
else:
res[mask] = f.to_global_data()[base.ang2pix(ptg)]
else:
ra = np.linspace(0, 2*np.pi, dom.nlon+1)
dec = pyHealpix.GL_thetas(dom.nlat)
ilat = _find_closest(dec, theta)
ilon = _find_closest(ra, phi)
ilon = np.where(ilon == dom.nlon, 0, ilon)
res[mask] = f[0].to_global_data()[ilat*dom.nlon + ilon]
if have_rgb:
res[mask] = rgb[ilat*dom[0].nlon + ilon]
else:
res[mask] = f.to_global_data()[ilat*dom.nlon + ilon]
plt.axis('off')
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower")
plt.colorbar(orientation="horizontal")
if have_rgb:
plt.imshow(res, origin="lower")
else:
plt.imshow(res, vmin=kwargs.get("zmin"), vmax=kwargs.get("zmax"),
cmap=cmap, origin="lower")
plt.colorbar(orientation="horizontal")
return
raise ValueError("Field type not(yet) supported")
def _plot(f, ax, **kwargs):
_register_cmaps()
if isinstance(f, Field):
f = [f]
f = list(f)
if len(f) == 0:
raise ValueError("need something to plot")
if not isinstance(f[0], Field):
raise TypeError("incorrect data type")
dom1 = f[0].domain
if (len(dom1)==1 and
(isinstance(dom1[0],PowerSpace) or
(isinstance(dom1[0], RGSpace) and len(dom1[0].shape) == 1))):
_plot1D(f, ax, **kwargs)
return
else:
if len(f) != 1:
raise ValueError("need exactly one Field for 2D plot")
_plot2D(f[0], ax, **kwargs)
return
raise ValueError("Field type not(yet) supported")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment