Skip to content
Snippets Groups Projects
Commit a7d97e19 authored by Theo Steininger's avatar Theo Steininger
Browse files

Simplified dtype checking in lm_transformation_factory.py

parent 71ca3eba
No related branches found
No related tags found
2 merge requests!69Master,!65Martin's monster merge part 2/N: partial de-cythonization
import numpy as np
def buildLm(inp, **kwargs):
if inp.dtype == np.dtype('float32'):
return _buildLm_f(inp, **kwargs)
else:
return _buildLm(inp, **kwargs)
def buildIdx(inp, **kwargs):
if inp.dtype == np.dtype('complex64'):
return _buildIdx_f(inp, **kwargs)
else:
return _buildIdx(inp, **kwargs)
def buildLm(nr, lmax):
new_dtype = np.result_type(nr.dtype, np.complex64)
def _buildIdx_f(nr, lmax):
size = (lmax+1)*(lmax+1)
size = (len(nr)-lmax-1)/2+lmax+1
res = np.zeros([size], dtype=new_dtype)
res[0:lmax+1] = nr[0:lmax+1]
res[lmax+1:] = np.sqrt(0.5)*(nr[lmax+1::2] + 1j*nr[lmax+2::2])
return res
final=np.zeros([size], dtype=np.float32)
final[0:lmax+1] = nr[0:lmax+1].real
final[lmax+1::2] = np.sqrt(2)*nr[lmax+1:].real
final[lmax+2::2] = np.sqrt(2)*nr[lmax+1:].imag
return final
def _buildIdx(nr, lmax):
size = (lmax+1)*(lmax+1)
def buildIdx(nr, lmax):
if nr.dtype == np.dtype('complex64'):
new_dtype = np.float32
elif nr.dtype == np.dtype('complex128'):
new_dtype = np.float64
elif nr.dtype == np.dtype('complex256'):
new_dtype = np.float128
else:
raise TypeError("dtype of nr not supported.")
final=np.zeros([size], dtype=np.float64)
size = (lmax+1)*(lmax+1)
final = np.zeros([size], dtype=new_dtype)
final[0:lmax+1] = nr[0:lmax+1].real
final[lmax+1::2] = np.sqrt(2)*nr[lmax+1:].real
final[lmax+2::2] = np.sqrt(2)*nr[lmax+1:].imag
return final
def _buildLm_f(nr, lmax):
size = (len(nr)-lmax-1)/2+lmax+1
res=np.zeros([size], dtype=np.complex64)
res[0:lmax+1] = nr[0:lmax+1]
res[lmax+1:] = np.sqrt(0.5)*(nr[lmax+1::2] + 1j*nr[lmax+2::2])
return res
def _buildLm(nr, lmax):
size = (len(nr)-lmax-1)/2+lmax+1
res=np.zeros([size], dtype=np.complex128)
res[0:lmax+1] = nr[0:lmax+1]
res[lmax+1:] = np.sqrt(0.5)*(nr[lmax+1::2] + 1j*nr[lmax+2::2])
return res
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment