From a7d97e193937b1c00ae6c6281e1a907633b5612c Mon Sep 17 00:00:00 2001 From: Theo Steininger <theos@mpa-garching.mpg.de> Date: Fri, 21 Apr 2017 05:53:15 +0200 Subject: [PATCH] Simplified dtype checking in lm_transformation_factory.py --- .../lm_transformation_factory.py | 54 +++++++------------ 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/nifty/operators/fft_operator/transformations/lm_transformation_factory.py b/nifty/operators/fft_operator/transformations/lm_transformation_factory.py index ad1f038a7..fa2442381 100644 --- a/nifty/operators/fft_operator/transformations/lm_transformation_factory.py +++ b/nifty/operators/fft_operator/transformations/lm_transformation_factory.py @@ -1,47 +1,29 @@ import numpy as np -def buildLm(inp, **kwargs): - if inp.dtype == np.dtype('float32'): - return _buildLm_f(inp, **kwargs) - else: - return _buildLm(inp, **kwargs) -def buildIdx(inp, **kwargs): - if inp.dtype == np.dtype('complex64'): - return _buildIdx_f(inp, **kwargs) - else: - return _buildIdx(inp, **kwargs) +def buildLm(nr, lmax): + new_dtype = np.result_type(nr.dtype, np.complex64) -def _buildIdx_f(nr, lmax): - size = (lmax+1)*(lmax+1) + size = (len(nr)-lmax-1)/2+lmax+1 + res = np.zeros([size], dtype=new_dtype) + res[0:lmax+1] = nr[0:lmax+1] + res[lmax+1:] = np.sqrt(0.5)*(nr[lmax+1::2] + 1j*nr[lmax+2::2]) + return res - final=np.zeros([size], dtype=np.float32) - final[0:lmax+1] = nr[0:lmax+1].real - final[lmax+1::2] = np.sqrt(2)*nr[lmax+1:].real - final[lmax+2::2] = np.sqrt(2)*nr[lmax+1:].imag - return final -def _buildIdx(nr, lmax): - size = (lmax+1)*(lmax+1) +def buildIdx(nr, lmax): + if nr.dtype == np.dtype('complex64'): + new_dtype = np.float32 + elif nr.dtype == np.dtype('complex128'): + new_dtype = np.float64 + elif nr.dtype == np.dtype('complex256'): + new_dtype = np.float128 + else: + raise TypeError("dtype of nr not supported.") - final=np.zeros([size], dtype=np.float64) + size = (lmax+1)*(lmax+1) + final = np.zeros([size], dtype=new_dtype) final[0:lmax+1] = nr[0:lmax+1].real final[lmax+1::2] = np.sqrt(2)*nr[lmax+1:].real final[lmax+2::2] = np.sqrt(2)*nr[lmax+1:].imag return final - -def _buildLm_f(nr, lmax): - size = (len(nr)-lmax-1)/2+lmax+1 - - res=np.zeros([size], dtype=np.complex64) - res[0:lmax+1] = nr[0:lmax+1] - res[lmax+1:] = np.sqrt(0.5)*(nr[lmax+1::2] + 1j*nr[lmax+2::2]) - return res - -def _buildLm(nr, lmax): - size = (len(nr)-lmax-1)/2+lmax+1 - - res=np.zeros([size], dtype=np.complex128) - res[0:lmax+1] = nr[0:lmax+1] - res[lmax+1:] = np.sqrt(0.5)*(nr[lmax+1::2] + 1j*nr[lmax+2::2]) - return res -- GitLab