diff --git a/nifty/operators/fft_operator/transformations/lm_transformation_factory.py b/nifty/operators/fft_operator/transformations/lm_transformation_factory.py index ad1f038a7de822c56802fb798235607409a90465..fa2442381f47fb7491db1b0b07ea30dd6127b6e2 100644 --- a/nifty/operators/fft_operator/transformations/lm_transformation_factory.py +++ b/nifty/operators/fft_operator/transformations/lm_transformation_factory.py @@ -1,47 +1,29 @@ import numpy as np -def buildLm(inp, **kwargs): - if inp.dtype == np.dtype('float32'): - return _buildLm_f(inp, **kwargs) - else: - return _buildLm(inp, **kwargs) -def buildIdx(inp, **kwargs): - if inp.dtype == np.dtype('complex64'): - return _buildIdx_f(inp, **kwargs) - else: - return _buildIdx(inp, **kwargs) +def buildLm(nr, lmax): + new_dtype = np.result_type(nr.dtype, np.complex64) -def _buildIdx_f(nr, lmax): - size = (lmax+1)*(lmax+1) + size = (len(nr)-lmax-1)/2+lmax+1 + res = np.zeros([size], dtype=new_dtype) + res[0:lmax+1] = nr[0:lmax+1] + res[lmax+1:] = np.sqrt(0.5)*(nr[lmax+1::2] + 1j*nr[lmax+2::2]) + return res - final=np.zeros([size], dtype=np.float32) - final[0:lmax+1] = nr[0:lmax+1].real - final[lmax+1::2] = np.sqrt(2)*nr[lmax+1:].real - final[lmax+2::2] = np.sqrt(2)*nr[lmax+1:].imag - return final -def _buildIdx(nr, lmax): - size = (lmax+1)*(lmax+1) +def buildIdx(nr, lmax): + if nr.dtype == np.dtype('complex64'): + new_dtype = np.float32 + elif nr.dtype == np.dtype('complex128'): + new_dtype = np.float64 + elif nr.dtype == np.dtype('complex256'): + new_dtype = np.float128 + else: + raise TypeError("dtype of nr not supported.") - final=np.zeros([size], dtype=np.float64) + size = (lmax+1)*(lmax+1) + final = np.zeros([size], dtype=new_dtype) final[0:lmax+1] = nr[0:lmax+1].real final[lmax+1::2] = np.sqrt(2)*nr[lmax+1:].real final[lmax+2::2] = np.sqrt(2)*nr[lmax+1:].imag return final - -def _buildLm_f(nr, lmax): - size = (len(nr)-lmax-1)/2+lmax+1 - - res=np.zeros([size], dtype=np.complex64) - res[0:lmax+1] = nr[0:lmax+1] - res[lmax+1:] = np.sqrt(0.5)*(nr[lmax+1::2] + 1j*nr[lmax+2::2]) - return res - -def _buildLm(nr, lmax): - size = (len(nr)-lmax-1)/2+lmax+1 - - res=np.zeros([size], dtype=np.complex128) - res[0:lmax+1] = nr[0:lmax+1] - res[lmax+1:] = np.sqrt(0.5)*(nr[lmax+1::2] + 1j*nr[lmax+2::2]) - return res