From db57b2dec57f697eb4481dac1ed799e72810f88f Mon Sep 17 00:00:00 2001
From: Marco Selig <mselig@ncg-02.MPA-Garching.MPG.DE>
Date: Wed, 11 Dec 2013 14:25:07 +0100
Subject: [PATCH] minor fixes; 'explicify' simplified.

---
 nifty_core.py     |  22 +++++-----
 nifty_explicit.py | 102 ++++++++++++++++++++++++++++++++++++++--------
 2 files changed, 96 insertions(+), 28 deletions(-)

diff --git a/nifty_core.py b/nifty_core.py
index 133683458..c644187bc 100644
--- a/nifty_core.py
+++ b/nifty_core.py
@@ -774,7 +774,7 @@ class random(object):
         """
         size = np.prod(shape,axis=0,dtype=np.int,out=None)
 
-        if(datatype in [np.complex64,np.complex128]):
+        if(issubclass(datatype,np.complexfloating)):
             x = np.array([1+0j,0+1j,-1+0j,0-1j],dtype=datatype)[np.random.randint(4,high=None,size=size)]
         else:
             x = 2*np.random.randint(2,high=None,size=size)-1
@@ -816,7 +816,7 @@ class random(object):
         """
         size = np.prod(shape,axis=0,dtype=np.int,out=None)
 
-        if(datatype in [np.complex64,np.complex128]):
+        if(issubclass(datatype,np.complexfloating)):
             x = np.empty(size,dtype=datatype,order='C')
             x.real = np.random.normal(loc=0,scale=np.sqrt(0.5),size=size)
             x.imag = np.random.normal(loc=0,scale=np.sqrt(0.5),size=size)
@@ -1941,7 +1941,7 @@ class point_space(space):
                 Number of degrees of freedom of the space.
         """
         ## dof ~ dim
-        if(self.datatype in [np.complex64,np.complex128]):
+        if(issubclass(self.datatype,np.complexfloating)):
             return 2*self.para[0]
         else:
             return self.para[0]
@@ -10609,7 +10609,7 @@ class probing(object):
 
         ## check codomain
         if(target is None):
-            target = domain.get_codomain()
+            target = self.domain.get_codomain()
         else:
             self.domain.check_codomain(target) ## a bit pointless
         self.target = target
@@ -11087,7 +11087,7 @@ class trace_probing(probing):
 
         ## check codomain
         if(target is None):
-            target = domain.get_codomain()
+            target = self.domain.get_codomain()
         else:
             self.domain.check_codomain(target) ## a bit pointless
         self.target = target
@@ -11171,7 +11171,7 @@ class trace_probing(probing):
         else:
             about.infos.cflush("\n")
 
-        if(self.domain.datatype in [np.complex64,np.complex128]):
+        if(issubclass(self.domain.datatype,np.complexfloating)):
             summa = np.real(summa)
 
         final = summa/num
@@ -11219,7 +11219,7 @@ class trace_probing(probing):
         ## define random seed
         seed = np.random.randint(10**8,high=None,size=self.nrun)
         ## define shared objects
-        if(self.domain.datatype in [np.complex64,np.complex128]):
+        if(issubclass(self.domain.datatype,np.complexfloating)):
             _sum = (mv('d',0,lock=True),mv('d',0,lock=True)) ## tuple(real,imag)
         else:
             _sum = mv('d',0,lock=True)
@@ -11246,7 +11246,7 @@ class trace_probing(probing):
             pool.join()
             raise Exception(about._errors.cstring("ERROR: unknown. NOTE: pool terminated.")) ## traceback by looping
         ## evaluate
-        if(self.domain.datatype in [np.complex64,np.complex128]):
+        if(issubclass(self.domain.datatype,np.complexfloating)):
             _sum = np.complex(_sum[0].value,_sum[1].value)
         else:
             _sum = _sum.value
@@ -11478,7 +11478,7 @@ class diagonal_probing(probing):
 
         ## check codomain
         if(target is None):
-            target = domain.get_codomain()
+            target = self.domain.get_codomain()
         else:
             self.domain.check_codomain(target) ## a bit pointless
         self.target = target
@@ -11646,7 +11646,7 @@ class diagonal_probing(probing):
         ## define random seed
         seed = np.random.randint(10**8,high=None,size=self.nrun)
         ## define shared objects
-        if(self.domain.datatype in [np.complex64,np.complex128]):
+        if(issubclass(self.domain.datatype,np.complexfloating)):
             _sum = (ma('d',np.zeros(self.domain.dim(split=False),dtype=np.float64,order='C'),lock=True),ma('d',np.zeros(self.domain.dim(split=False),dtype=np.float64,order='C'),lock=True)) ## tuple(real,imag)
         else:
             _sum = ma('d',np.zeros(self.domain.dim(split=False),dtype=np.float64,order='C'),lock=True)
@@ -11673,7 +11673,7 @@ class diagonal_probing(probing):
             pool.join()
             raise Exception(about._errors.cstring("ERROR: unknown. NOTE: pool terminated.")) ## traceback by looping
         ## evaluate
-        if(self.domain.datatype in [np.complex64,np.complex128]):
+        if(issubclass(self.domain.datatype,np.complexfloating)):
             _sum = (np.array(_sum[0][:])+np.array(_sum[1][:])*1j).reshape(self.domain.dim(split=True)) ## comlpex array
         else:
             _sum = np.array(_sum[:]).reshape(self.domain.dim(split=True))
diff --git a/nifty_explicit.py b/nifty_explicit.py
index d79daf906..7e8d4a400 100644
--- a/nifty_explicit.py
+++ b/nifty_explicit.py
@@ -107,8 +107,8 @@ class explicit_operator(operator):
             if(len(bare)!=2):
                 raise ValueError(about._errors.cstring("ERROR: invalid input."))
             else:
-                val = self._calc_weight_rows(val,-bool(bare[0]))
-                val = self._calc_weight_cols(val,-bool(bare[1]))
+                val = self._calc_weight_rows(val,power=-int(not bare[0]))
+                val = self._calc_weight_cols(val,power=-int(not bare[1]))
         elif(not bare):
             val = self._calc_weight_rows(val,-1)
         if(purelyreal):
@@ -156,6 +156,38 @@ class explicit_operator(operator):
 
     ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 
+    def cast_domain(self,newdomain):
+        """
+            TODO: documentation
+
+        """
+        if(not isinstance(newdomain,space)):
+            raise TypeError(about._errors.cstring("ERROR: invalid input."))
+        elif(newdomain.dim(split=False)!=self.domain.dim(split=False)):
+            raise ValueError(about._errors.cstring("ERROR: dimension mismatch ( "+str(newdomain.dim(split=False))+" <> "+str(self.domain.dim(split=False))+" )."))
+        self.domain = newdomain
+
+    def cast_target(self,newtarget):
+        """
+            TODO: documentation
+
+        """
+        if(not isinstance(newtarget,space)):
+            raise TypeError(about._errors.cstring("ERROR: invalid input."))
+        elif(newtarget.dim(split=False)!=self.target.dim(split=False)):
+            raise ValueError(about._errors.cstring("ERROR: dimension mismatch ( "+str(newtarget.dim(split=False))+" <> "+str(self.target.dim(split=False))+" )."))
+        self.target = newtarget
+
+    def cast_spaces(self,newdomain,newtarget):
+        """
+            TODO: documentation
+
+        """
+        self.cast_domain(newdomain)
+        self.cast_target(newtarget)
+
+    ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
     def set_matrix(self,newmatrix,bare=True,sym=None,uni=None):
         """
             TODO: documentation
@@ -182,8 +214,8 @@ class explicit_operator(operator):
             if(len(bare)!=2):
                 raise ValueError(about._errors.cstring("ERROR: invalid input."))
             else:
-                val = self._calc_weight_rows(val,-bool(bare[0]))
-                val = self._calc_weight_cols(val,-bool(bare[1]))
+                val = self._calc_weight_rows(val,power=-int(not bare[0]))
+                val = self._calc_weight_cols(val,power=-int(not bare[1]))
         elif(not bare):
             val = self._calc_weight_rows(val,-1)
         if(purelyreal):
@@ -209,9 +241,9 @@ class explicit_operator(operator):
             if(len(bare)!=2):
                 raise ValueError(about._errors.cstring("ERROR: invalid input."))
             else:
-                return self.weight(rowpower=bool(bare[0]),colpower=bool(bare[1]),overwrite=False)
+                return self.weight(rowpower=int(not bare[0]),colpower=int(not bare[1]),overwrite=False)
         elif(not bare):
-            return self.weight(rowpower=bool(bare),colpower=0,overwrite=False)
+            return self.weight(rowpower=int(not bare),colpower=0,overwrite=False)
 
     ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 
@@ -222,10 +254,16 @@ class explicit_operator(operator):
         return self.target.calc_weight(x,power=power).flatten(order='C')
 
     def _calc_weight_rows(self,X,power=1): ## > weight all rows
-        return np.apply_along_axis(self._calc_weight_row,1,X,power)
+        if(np.any(np.iscomplex(X)))and(not issubclass(self.domain.datatype,np.complexfloating)):
+            return (np.apply_along_axis(self._calc_weight_row,1,np.real(X),power)+np.apply_along_axis(self._calc_weight_row,1,np.imag(X),power)*1j)
+        else:
+            return np.apply_along_axis(self._calc_weight_row,1,X,power)
 
     def _calc_weight_cols(self,X,power=1): ## > weight all columns
-        return np.apply_along_axis(self._calc_weight_col,0,X,power)
+        if(np.any(np.iscomplex(X)))and(not issubclass(self.target.datatype,np.complexfloating)):
+            return (np.apply_along_axis(self._calc_weight_col,0,np.real(X),power)+np.apply_along_axis(self._calc_weight_col,0,np.imag(X),power)*1j)
+        else:
+            return np.apply_along_axis(self._calc_weight_col,0,X,power)
 
     def weight(self,rowpower=0,colpower=0,overwrite=False):
         """
@@ -1448,7 +1486,7 @@ class explicit_probing(probing):
 #            Keyword arguments passed to `function` in each call.
 #
 #    """
-    def __init__(self,op=None,function=None,domain=None,codomain=None,target=None,ncpu=2,nper=None,**quargs):
+    def __init__(self,op=None,function=None,domain=None,codomain=None,target=None,ncpu=2,nper=1,**quargs):
         """
             TODO: documentation
 
@@ -1489,7 +1527,7 @@ class explicit_probing(probing):
 #            If on the other hand `nper=1`, then for each evaluation a worker will
 #            be created. In this case all cpus will work until nrun probes have
 #            been evaluated.
-#            It is recommended to leave `nper` as the default value. (default: 8)
+#            It is recommended to leave `nper` as the default value. (default: 1)
 #
 #        """
         if(op is None):
@@ -1546,12 +1584,12 @@ class explicit_probing(probing):
         self.domain = domain
         self.codomain = codomain
 
+        ## check target
         if(target is None):
-            self.target = domain.get_codomain()
+            target = self.domain.get_codomain()
         else:
-            ## check codomain
             self.domain.check_codomain(target) ## a bit pointless
-            self.target = target
+        self.target = target
 
         ## check shape
         if(self.domain.dim(split=False)*self.codomain.dim(split=False)>1048576):
@@ -1717,7 +1755,7 @@ class explicit_probing(probing):
         ## define weighted canonical base
         base = self.domain.calc_weight(self.domain.enforce_values(1,extend=True),power=-1).flatten(order='C')
         ## define shared objects
-        if(self.codomain.datatype in [np.complex64,np.complex128]):
+        if(issubclass(self.codomain.datatype,np.complexfloating)):
             _mat = (ma('d',np.empty(self.nrun*self.codomain.dim(split=False),dtype=np.float64,order='C'),lock=True),ma('d',np.empty(self.nrun*self.codomain.dim(split=False),dtype=np.float64,order='C'),lock=True)) ## tuple(real,imag)
         else:
             _mat = ma('d',np.empty(self.nrun*self.codomain.dim(split=False),dtype=np.float64,order='C'),lock=True)
@@ -1740,7 +1778,7 @@ class explicit_probing(probing):
             pool.join()
             raise Exception(about._errors.cstring("ERROR: unknown. NOTE: pool terminated.")) ## traceback by looping
         ## evaluate
-        if(self.domain.datatype in [np.complex64,np.complex128]):
+        if(issubclass(self.codomain.datatype,np.complexfloating)):
             _mat = (np.array(_mat[0][:])+np.array(_mat[1][:])*1j).reshape((self.nrun,self.codomain.dim(split=False))) ## comlpex array
         else:
             _mat = np.array(_mat[:]).reshape((self.nrun,self.codomain.dim(split=False)))
@@ -1805,12 +1843,42 @@ class explicit_probing(probing):
 
 ##-----------------------------------------------------------------------------
 
-def explicify(operator,loop=False,**kwargs):
+def explicify(operator,newdomain=None,newtarget=None,ncpu=2,nper=1,loop=False,**quargs):
     """
         TODO: documentation
 
     """
-    return explicit_probing(op=operator,**kwargs)(loop=loop)
+    return explicit_probing(op=operator,function=operator.times,domain=newdomain,codomain=newtarget,target=operator.domain,ncpu=ncpu,nper=nper,**quargs)(loop=loop)
 
 ##-----------------------------------------------------------------------------
 
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-- 
GitLab