From a7d97e193937b1c00ae6c6281e1a907633b5612c Mon Sep 17 00:00:00 2001
From: Theo Steininger <theos@mpa-garching.mpg.de>
Date: Fri, 21 Apr 2017 05:53:15 +0200
Subject: [PATCH] Simplified dtype checking in lm_transformation_factory.py

---
 .../lm_transformation_factory.py              | 54 +++++++------------
 1 file changed, 18 insertions(+), 36 deletions(-)

diff --git a/nifty/operators/fft_operator/transformations/lm_transformation_factory.py b/nifty/operators/fft_operator/transformations/lm_transformation_factory.py
index ad1f038a7..fa2442381 100644
--- a/nifty/operators/fft_operator/transformations/lm_transformation_factory.py
+++ b/nifty/operators/fft_operator/transformations/lm_transformation_factory.py
@@ -1,47 +1,29 @@
 import numpy as np
 
-def buildLm(inp, **kwargs):
-    if inp.dtype == np.dtype('float32'):
-        return _buildLm_f(inp, **kwargs)
-    else:
-        return _buildLm(inp, **kwargs)
 
-def buildIdx(inp, **kwargs):
-    if inp.dtype == np.dtype('complex64'):
-        return _buildIdx_f(inp, **kwargs)
-    else:
-        return _buildIdx(inp, **kwargs)
+def buildLm(nr, lmax):
+    new_dtype = np.result_type(nr.dtype, np.complex64)
 
-def _buildIdx_f(nr, lmax):
-    size = (lmax+1)*(lmax+1)
+    size = (len(nr)-lmax-1)/2+lmax+1
+    res = np.zeros([size], dtype=new_dtype)
+    res[0:lmax+1] = nr[0:lmax+1]
+    res[lmax+1:] = np.sqrt(0.5)*(nr[lmax+1::2] + 1j*nr[lmax+2::2])
+    return res
 
-    final=np.zeros([size], dtype=np.float32)
-    final[0:lmax+1] = nr[0:lmax+1].real
-    final[lmax+1::2] = np.sqrt(2)*nr[lmax+1:].real
-    final[lmax+2::2] = np.sqrt(2)*nr[lmax+1:].imag
-    return final
 
-def _buildIdx(nr, lmax):
-    size = (lmax+1)*(lmax+1)
+def buildIdx(nr, lmax):
+    if nr.dtype == np.dtype('complex64'):
+        new_dtype = np.float32
+    elif nr.dtype == np.dtype('complex128'):
+        new_dtype = np.float64
+    elif nr.dtype == np.dtype('complex256'):
+        new_dtype = np.float128
+    else:
+        raise TypeError("dtype of nr not supported.")
 
-    final=np.zeros([size], dtype=np.float64)
+    size = (lmax+1)*(lmax+1)
+    final = np.zeros([size], dtype=new_dtype)
     final[0:lmax+1] = nr[0:lmax+1].real
     final[lmax+1::2] = np.sqrt(2)*nr[lmax+1:].real
     final[lmax+2::2] = np.sqrt(2)*nr[lmax+1:].imag
     return final
-
-def _buildLm_f(nr, lmax):
-    size = (len(nr)-lmax-1)/2+lmax+1
-
-    res=np.zeros([size], dtype=np.complex64)
-    res[0:lmax+1] = nr[0:lmax+1]
-    res[lmax+1:] = np.sqrt(0.5)*(nr[lmax+1::2] + 1j*nr[lmax+2::2])
-    return res
-
-def _buildLm(nr, lmax):
-    size = (len(nr)-lmax-1)/2+lmax+1
-
-    res=np.zeros([size], dtype=np.complex128)
-    res[0:lmax+1] = nr[0:lmax+1]
-    res[lmax+1:] = np.sqrt(0.5)*(nr[lmax+1::2] + 1j*nr[lmax+2::2])
-    return res
-- 
GitLab