From 007da99a1520a8531781e2c5aeb723891d378c9c Mon Sep 17 00:00:00 2001
From: csongor <csongor.varady@gmail.com>
Date: Fri, 2 Sep 2016 10:33:30 -0700
Subject: [PATCH] Taking care of types

---
 .../transformations/gllmtransformation.py     | 19 ++---
 .../lm_transformation_factory.pyx             | 78 +++++++++++++++++--
 .../transformations/lmgltransformation.py     | 21 ++---
 3 files changed, 91 insertions(+), 27 deletions(-)

diff --git a/nifty/operators/fft_operator/transformations/gllmtransformation.py b/nifty/operators/fft_operator/transformations/gllmtransformation.py
index 2ab318f9c..e686b0ffc 100644
--- a/nifty/operators/fft_operator/transformations/gllmtransformation.py
+++ b/nifty/operators/fft_operator/transformations/gllmtransformation.py
@@ -111,23 +111,18 @@ class GLLMTransformation(Transformation):
                     return_val = np.empty_like(temp_val)
                 inp = temp_val[slice_list]
 
-            if self.domain.dtype == np.dtype('complex128'):
-                inpReal = gl.map2alm(
+            if inp.dtype >= np.dtype('complex64'):
+                inpReal = self.GlMap2Alm(
                     np.real(inp).astype(np.float64, copy=False), nlat=nlat,
                     nlon=nlon, lmax=lmax, mmax=mmax)
-                inpImg = gl.map2alm(
+                inpImg = self.GlMap2Alm(
                     np.imag(inp).astype(np.float64, copy=False), nlat=nlat,
                     nlon=nlon, lmax=lmax, mmax=mmax)
                 inpReal = ltf.buildIdx(inpReal, lmax=lmax)
                 inpImg = ltf.buildIdx(inpImg, lmax=lmax)
                 inp = inpReal + inpImg * 1j
             else:
-                if self.domain.dtype == np.dtype('float32'):
-                    inp = gl.map2alm_f(inp,
-                                       nlat=nlat, nlon=nlon,
-                                       lmax=lmax, mmax=mmax)
-                else:
-                    inp = gl.map2alm(inp,
+                inp = self.GlMap2Alm(inp,
                                      nlat=nlat, nlon=nlon,
                                      lmax=lmax, mmax=mmax)
                 inp = ltf.buildIdx(inp, lmax=lmax)
@@ -144,3 +139,9 @@ class GLLMTransformation(Transformation):
             return_val = return_val.astype(self.codomain.dtype, copy=False)
 
         return return_val
+
+    def GlMap2Alm(self, inp, **kwargs):
+        if inp.dtype == np.dtype('float32'):
+            return gl.map2alm_f(inp, kwargs)
+        else:
+            return gl.map.alm(inp, kwargs)
diff --git a/nifty/operators/fft_operator/transformations/lm_transformation_factory.pyx b/nifty/operators/fft_operator/transformations/lm_transformation_factory.pyx
index d7f02c437..a561c5b12 100644
--- a/nifty/operators/fft_operator/transformations/lm_transformation_factory.pyx
+++ b/nifty/operators/fft_operator/transformations/lm_transformation_factory.pyx
@@ -1,7 +1,45 @@
 import numpy as np
 cimport numpy as np
 
-cpdef buildIdx(np.ndarray[np.complex128_t, ndim=1] nr, np.int_t lmax):
+def buildLm(inp, **kwargs):
+    if inp.dtype == np.dtype('float32'):
+        return _buildLm_f(inp, **kwargs)
+    else:
+        return _buildLm(inp, **kwargs)
+
+def buildIdx(inp, **kwargs):
+    if inp.dtype == np.dtype('complex64'):
+        return _buildIdx_f(inp, **kwargs)
+    else:
+        return _buildIdx(inp, **kwargs)
+
+cpdef np.ndarray[np.float32_t, ndim=1]  _buildIdx_f(np.ndarray[np
+.complex64_t, ndim=1] nr, np.int_t lmax):
+    cdef np.int size = (lmax+1)*(lmax+1)
+
+    cdef np.ndarray res=np.zeros([size], dtype=np.complex64)
+    cdef np.ndarray final=np.zeros([size], dtype=np.float32)
+    res[0:lmax+1] = nr[0:lmax+1]
+
+    for i in xrange(len(nr)-lmax-1):
+        res[i*2+lmax+1] = nr[i+lmax+1]
+        res[i*2+lmax+2] = np.conjugate(nr[i+lmax+1])
+    final = _realify_f(res, lmax)
+    return final
+
+cpdef np.ndarray[np.float32_t, ndim=1] _realify_f(np.ndarray[np.complex64_t, ndim=1] nr, np.int_t lmax):
+    cdef np.ndarray resi=np.zeros([len(nr)], dtype=np.float32)
+
+    resi[0:lmax+1] = np.real(nr[0:lmax+1])
+
+    for i in xrange(lmax+1,len(nr),2):
+        mi =  int(np.ceil(((2*lmax+1)-np.sqrt((2*lmax+1)*(2*lmax+1)-4*(i-lmax)+1))/2))
+        resi[i]=np.sqrt(2)*np.real(nr[i])*(-1)**(mi*mi)
+        resi[i+1]=np.sqrt(2)*np.imag(nr[i])*(-1)**(mi*mi)
+    return resi
+
+cpdef np.ndarray[np.float64_t, ndim=1]  _buildIdx(np.ndarray[np.complex128_t,
+ ndim=1] nr, np.int_t lmax):
     cdef np.int size = (lmax+1)*(lmax+1)
 
     cdef np.ndarray res=np.zeros([size], dtype=np.complex128)
@@ -11,40 +49,64 @@ cpdef buildIdx(np.ndarray[np.complex128_t, ndim=1] nr, np.int_t lmax):
     for i in xrange(len(nr)-lmax-1):
         res[i*2+lmax+1] = nr[i+lmax+1]
         res[i*2+lmax+2] = np.conjugate(nr[i+lmax+1])
-    final = realify(res, lmax)
+    final = _realify(res, lmax)
     return final
 
-cpdef np.ndarray[np.float64_t, ndim=1] realify(np.ndarray[np.complex128_t, ndim=1] nr, np.int_t lmax):
+cpdef np.ndarray[np.float64_t, ndim=1] _realify(np.ndarray[np.complex128_t, ndim=1] nr, np.int_t lmax):
     cdef np.ndarray resi=np.zeros([len(nr)], dtype=np.float64)
 
     resi[0:lmax+1] = np.real(nr[0:lmax+1])
 
     for i in xrange(lmax+1,len(nr),2):
-        # m calculation print i,(i-lmax)/2+1,(2*lmax+1)*(2*lmax+1)-4*(i-lmax)+1
         mi =  int(np.ceil(((2*lmax+1)-np.sqrt((2*lmax+1)*(2*lmax+1)-4*(i-lmax)+1))/2))
         resi[i]=np.sqrt(2)*np.real(nr[i])*(-1)**(mi*mi)
         resi[i+1]=np.sqrt(2)*np.imag(nr[i])*(-1)**(mi*mi)
     return resi
 
-cpdef buildLm(np.ndarray[np.float64_t, ndim=1] nr, np.int_t lmax):
+cpdef np.ndarray[np.complex64_t, ndim=1] _buildLm_f(np.ndarray[np.float32_t,
+ndim=1] nr, np.int_t lmax):
+    cdef np.int size = (len(nr)-lmax-1)/2+lmax+1
+
+    cdef np.ndarray res=np.zeros([size], dtype=np.complex64)
+    cdef np.ndarray temp=np.zeros([len(nr)], dtype=np.complex64)
+    res[0:lmax+1] = nr[0:lmax+1]
+
+    temp = _inverseRealify_f(nr, lmax)
+
+    for i in xrange(0,len(temp)-lmax-1,2):
+        res[i/2+lmax+1] = temp[i+lmax+1]
+    return res
+
+cpdef np.ndarray[np.complex64_t, ndim=1] _inverseRealify_f(np.ndarray[np.float32_t, ndim=1] nr, np.int_t lmax):
+    cdef np.ndarray resi=np.zeros([len(nr)], dtype=np.complex64)
+    resi[0:lmax+1] = np.real(nr[0:lmax+1])
+
+    for i in xrange(lmax+1,len(nr),2):
+        mi =  int(np.ceil(((2*lmax+1)-np.sqrt((2*lmax+1)*(2*lmax+1)-4*(i-lmax)+1))/2))
+        resi[i]=(-1)**mi/np.sqrt(2)*(nr[i]+1j*nr[i+1])
+        resi[i+1]=1/np.sqrt(2)*(nr[i]-1j*nr[i+1])
+    return resi
+
+
+cpdef np.ndarray[np.complex128_t, ndim=1] _buildLm(np.ndarray[np.float64_t,
+ndim=1] nr, np.int_t lmax):
     cdef np.int size = (len(nr)-lmax-1)/2+lmax+1
 
     cdef np.ndarray res=np.zeros([size], dtype=np.complex128)
     cdef np.ndarray temp=np.zeros([len(nr)], dtype=np.complex128)
     res[0:lmax+1] = nr[0:lmax+1]
 
-    temp = inverseRealify(nr, lmax)
+    temp = _inverseRealify(nr, lmax)
 
     for i in xrange(0,len(temp)-lmax-1,2):
         res[i/2+lmax+1] = temp[i+lmax+1]
     return res
 
-cpdef np.ndarray[np.complex128_t, ndim=1] inverseRealify(np.ndarray[np.float64_t, ndim=1] nr, np.int_t lmax):
+cpdef np.ndarray[np.complex128_t, ndim=1] _inverseRealify(np.ndarray[np.float64_t, ndim=1] nr, np.int_t lmax):
     cdef np.ndarray resi=np.zeros([len(nr)], dtype=np.complex128)
     resi[0:lmax+1] = np.real(nr[0:lmax+1])
 
     for i in xrange(lmax+1,len(nr),2):
-        # m calculation print i,(i-lmax)/2+1,(2*lmax+1)*(2*lmax+1)-4*(i-lmax)+1
         mi =  int(np.ceil(((2*lmax+1)-np.sqrt((2*lmax+1)*(2*lmax+1)-4*(i-lmax)+1))/2))
         resi[i]=(-1)**mi/np.sqrt(2)*(nr[i]+1j*nr[i+1])
         resi[i+1]=1/np.sqrt(2)*(nr[i]-1j*nr[i+1])
diff --git a/nifty/operators/fft_operator/transformations/lmgltransformation.py b/nifty/operators/fft_operator/transformations/lmgltransformation.py
index 99d2c70b1..6199671f4 100644
--- a/nifty/operators/fft_operator/transformations/lmgltransformation.py
+++ b/nifty/operators/fft_operator/transformations/lmgltransformation.py
@@ -118,25 +118,20 @@ class LMGLTransformation(Transformation):
             lmax = self.domain.lmax
             mmax = self.mmax
 
-            if self.domain.dtype == np.dtype('complex128'):
+            if inp.dtype >= np.dtype('complex64'):
                 inpReal = np.real(inp)
                 inpImag = np.imag(inp)
                 inpReal = ltf.buildLm(inpReal,lmax=lmax)
                 inpImag = ltf.buildLm(inpImag,lmax=lmax)
-                inpReal = gl.alm2map(inpReal, nlat=nlat, nlon=nlon,
+                inpReal = self.GlAlm2Map(inpReal, nlat=nlat, nlon=nlon,
                                  lmax=lmax, mmax=mmax, cl=False)
-                inpImag = gl.alm2map(inpImag, nlat=nlat, nlon=nlon,
+                inpImag = self.GlAlm2Map(inpImag, nlat=nlat, nlon=nlon,
                                  lmax=lmax, mmax=mmax, cl=False)
                 inp = inpReal+inpImag*(1j)
             else:
                 inp = ltf.buildLm(inp, lmax=lmax)
-
-                if self.domain.dtype == np.dtype('complex64'):
-                    inp = gl.alm2map_f(inp, nlat=nlat, nlon=nlon,
-                                       lmax=lmax, mmax=mmax, cl=False)
-                else:
-                    inp = gl.alm2map(inp, nlat=nlat, nlon=nlon,
-                                     lmax=lmax, mmax=mmax, cl=False)
+                inp = self.GlAlm2Map(inp, nlat=nlat, nlon=nlon,
+                                   lmax=lmax, mmax=mmax, cl=False)
 
             if slice_list == [slice(None, None)]:
                 return_val = inp
@@ -154,3 +149,9 @@ class LMGLTransformation(Transformation):
             return_val = return_val.astype(self.codomain.dtype, copy=False)
 
         return return_val
+
+    def GlAlm2Map(self, inp, **kwargs):
+        if inp.dtype == np.dtype('complex64'):
+            return gl.alm2map_f(inp, kwargs)
+        else:
+            return gl.alm2map(inp, kwargs)
-- 
GitLab