Commit 8c929980 authored by csongor's avatar csongor

WIP: lm transformations in cython

parent 31a4e7e3
......@@ -4,7 +4,9 @@ from d2o import distributed_data_object
from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty import GLSpace, LMSpace
import lm_transformation_factory as ltf
hp = gdi.get('healpy')
gl = gdi.get('libsharp_wrapper_gl')
......@@ -110,14 +112,27 @@ class GLLMTransformation(Transformation):
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
if self.domain.dtype == np.dtype('float32'):
inp = gl.map2alm_f(inp,
nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax)
if self.domain.dtype == np.dtype('complex128'):
inpReal = gl.map2alm(
np.real(inp).astype(np.float64, copy=False), nlat=nlat,
nlon=nlon, lmax=lmax, mmax=mmax)
inpImg = gl.map2alm(
np.imag(inp).astype(np.float64, copy=False), nlat=nlat,
nlon=nlon, lmax=lmax, mmax=mmax)
#TODO gl shouldn't depend on hp
lmaxArray, mmaxArray = hp.Alm.getlm(lmax=lmax)
inpReal = ltf.buildIdx(inpReal, lmaxArray, mmaxArray)
inpImg = ltf.buildIdx(inpImg, lmaxArray, mmaxArray)
inp = inpReal + inpImg * 1j
else:
inp = gl.map2alm(inp,
nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax)
if self.domain.dtype == np.dtype('float32'):
inp = gl.map2alm_f(inp,
nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax)
else:
inp = gl.map2alm(inp,
nlat=nlat, nlon=nlon,
lmax=lmax, mmax=mmax)
if slice_list == [slice(None, None)]:
return_val = inp
......
......@@ -5,6 +5,8 @@ from nifty.config import dependency_injector as gdi
import nifty.nifty_utilities as utilities
from nifty import HPSpace, LMSpace
import lm_transformation_factory as ltf
hp = gdi.get('healpy')
......@@ -106,9 +108,23 @@ class HPLMTransformation(Transformation):
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
inp = hp.map2alm(inp.astype(np.float64, copy=False),
lmax=lmax, mmax=mmax, iter=niter, pol=True,
use_weights=False, datapath=None)
if self.domain.dtype == np.dtype('complex128'):
inpReal = hp.map2alm(
np.real(inp).astype(np.float64, copy=False),
lmax=lmax, mmax=mmax, iter=niter, pol=True,
use_weights=False, datapath=None)
inpImg = hp.map2alm(
np.imag(inp).astype(np.float64, copy=False),
lmax=lmax, mmax=mmax, iter=niter, pol=True,
use_weights=False, datapath=None)
lmaxArray, mmaxArray = hp.Alm.getlm(lmax=lmax)
inpReal = ltf.buildIdx(inpReal,lmaxArray, mmaxArray)
inpImg = ltf.buildIdx(inpImg,lmaxArray, mmaxArray)
inp = inpReal + inpImg * 1j
else:
inp = hp.map2alm(inp.astype(np.float64, copy=False),
lmax=lmax, mmax=mmax, iter=niter, pol=True,
use_weights=False, datapath=None)
if slice_list == [slice(None, None)]:
return_val = inp
......
import numpy as np
cimport numpy as np
cpdef buildIdx(np.ndarray[np.complex128_t, ndim=1] nr, np.ndarray[np.int_t] l, np.ndarray[np.int_t] m, np.int_t lmax):
cdef np.int size = (lmax+1)*(lmax+1)
cdef np.ndarray res=np.zeros([size], dtype=np.complex128)
cdef np.ndarray final=np.zeros([size], dtype=np.float64)
res[0:lmax+1] = nr[0:lmax+1]
cdef np.ndarray resL=np.zeros([size], dtype=np.int)
resL[0:lmax+1] = np.arange(lmax+1)
cdef np.ndarray resM=np.zeros([size], dtype=np.int)
for i in xrange(len(nr)-lmax-1):
res[i*2+lmax+1] = nr[i+lmax+1]
res[i*2+lmax+2] = np.conjugate(nr[i+lmax+1])
resL[i*2+lmax+1] = l[i+lmax+1]
resL[i*2+lmax+2] = l[i+lmax+1]
resM[i*2+lmax+1] = m[i+lmax+1]
resM[i*2+lmax+2] = -m[i+lmax+1]
final = realify(res,resL, resM)
return final, resL, resM
cpdef buildLm(np.ndarray[np.float64_t, ndim=1] nr, np.ndarray[np.int_t] l, np.ndarray[np.int_t] m, np.int_t lmax):
cdef np.int size = (len(nr)-lmax-1)/2+lmax+1
cdef np.ndarray res=np.zeros([size], dtype=np.complex128)
cdef np.ndarray temp=np.zeros([len(nr)], dtype=np.complex128)
res[0:lmax+1] = nr[0:lmax+1]
cdef np.ndarray resL=np.zeros([size], dtype=np.int)
resL[0:lmax+1] = np.arange(lmax+1)
cdef np.ndarray resM=np.zeros([size], dtype=np.int)
temp = inverseRealify(nr, l, m)
for i in xrange(0,len(temp)-lmax-1,2):
res[i/2+lmax+1] = temp[i+lmax+1]
resL[i/2+lmax+1] = l[i+lmax+1]
resM[i/2+lmax+1] = m[i+lmax+1]
return res,resL,resM
cpdef np.ndarray[np.float64_t, ndim=1] realify(np.ndarray[np.complex128_t, ndim=1] nr, np.ndarray[np.int_t] l, np.ndarray[np.int_t] m):
cdef np.ndarray resi=np.zeros([len(nr)], dtype=np.float64)
for i in xrange(len(nr)):
if m[i]<0:
resi[i]=np.sqrt(2)*np.imag(nr[i-1])*(-1)**(m[i]*m[i])
elif m[i]>0:
resi[i]=np.sqrt(2)*np.real(nr[i])*(-1)**(m[i]*m[i])
else:
resi[i]=np.real(nr[i])
return resi
cpdef np.ndarray[np.complex128_t, ndim=1] inverseRealify(np.ndarray[np.float64_t, ndim=1] nr, np.ndarray[np.int_t] l, np.ndarray[np.int_t] m):
cdef np.ndarray resi=np.zeros([len(nr)], dtype=np.complex128)
for i in xrange(len(nr)):
if m[i]<0:
resi[i]=1/np.sqrt(2)*(nr[i-1]-1j*nr[i])
elif m[i]>0:
resi[i]=(-1)**m[i]/np.sqrt(2)*(nr[i]+1j*nr[i+1])
else:
resi[i]=np.real(nr[i])
return resi
This diff is collapsed.
# -*- coding: utf-8 -*-
#import numpy as np
#from nifty.config import about
#from nifty.spaces.space import SpaceParadict
#
#
#class LMSpaceParadict(SpaceParadict):
#
# def __init__(self, lmax, mmax):
# SpaceParadict.__init__(self, lmax=lmax)
# if mmax is None:
# mmax = -1
# self['mmax'] = mmax
#
# def __setitem__(self, key, arg):
# if key not in ['lmax', 'mmax']:
# raise ValueError(about._errors.cstring(
# "ERROR: Unsupported LMSpace parameter: " + key))
#
# if key == 'lmax':
# temp = np.int(arg)
# if temp < 1:
# raise ValueError(about._errors.cstring(
# "ERROR: lmax: nonpositive number."))
# # exception lmax == 2 (nside == 1)
# if (temp % 2 == 0) and (temp > 2):
# about.warnings.cprint(
# "WARNING: unrecommended parameter (lmax <> 2*n+1).")
# try:
# if temp < self['mmax']:
# about.warnings.cprint(
# "WARNING: mmax parameter set to lmax.")
# self['mmax'] = temp
# if (temp != self['mmax']):
# about.warnings.cprint(
# "WARNING: unrecommended parameter set (mmax <> lmax).")
# except:
# pass
# elif key == 'mmax':
# temp = int(arg)
# if (temp < 1) or(temp > self['lmax']):
# about.warnings.cprint(
# "WARNING: mmax parameter set to default.")
# temp = self['lmax']
# if(temp != self['lmax']):
# about.warnings.cprint(
# "WARNING: unrecommended parameter set (mmax <> lmax).")
#
# self.parameters.__setitem__(key, temp)
......@@ -22,8 +22,14 @@
from setuptools import setup, find_packages
import os
from Cython.Build import cythonize
exec(open('nifty/version.py').read())
ext_modules = cythonize(
"nifty/operators/fft_operator/transformations/lm_transformation_factory"
".pyx")
setup(name="ift_nifty",
version=__version__,
author="Theo Steininger",
......@@ -33,6 +39,7 @@ setup(name="ift_nifty",
packages=find_packages(),
package_dir={"nifty": "nifty"},
zip_safe=False,
ext_modules=ext_modules,
dependency_links=[
'git+https://gitlab.mpcdf.mpg.de/ift/keepers.git#egg=keepers',
'git+https://gitlab.mpcdf.mpg.de/ift/d2o.git#egg=d2o'],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment