Commit 8f7d9547 authored by Marco Selig's avatar Marco Selig

nkdict_fast implemented.

parent 86d22470
...@@ -23,7 +23,7 @@ from __future__ import division ...@@ -23,7 +23,7 @@ from __future__ import division
from nifty_core import * from nifty_core import *
from nifty_cmaps import * from nifty_cmaps import *
from nifty_power import * from nifty_power import *
#from nifty_tools impoert * from nifty_tools import *
......
...@@ -2422,7 +2422,7 @@ class rg_space(space): ...@@ -2422,7 +2422,7 @@ class rg_space(space):
if(kindex is None): if(kindex is None):
## quick kindex ## quick kindex
if(self.fourier)and(not hasattr(self,"power_indices"))and(len(kwargs)==0): if(self.fourier)and(not hasattr(self,"power_indices"))and(len(kwargs)==0):
kindex = gp.nklength(gp.nkdict(self.para[:(np.size(self.para)-1)//2],self.vol,fourier=True)) kindex = gp.nklength(gp.nkdict_fast(self.para[:(np.size(self.para)-1)//2],self.vol,fourier=True))
## implicit kindex ## implicit kindex
else: else:
try: try:
......
...@@ -291,7 +291,7 @@ def _calc_inverse(tk,var,kindex,rho,b1,Amem): ## > computes the inverse Hessian ...@@ -291,7 +291,7 @@ def _calc_inverse(tk,var,kindex,rho,b1,Amem): ## > computes the inverse Hessian
## inversion ## inversion
return np.linalg.inv(T2+np.diag(b2,k=0)),b2,Amem return np.linalg.inv(T2+np.diag(b2,k=0)),b2,Amem
def infer_power(m,domain=None,Sk=None,D=None,pindex=None,pundex=None,kindex=None,rho=None,q=1E-42,alpha=1,perception=(1,0),smoothness=True,var=10,bare=True,**kwargs): def infer_power(m,domain=None,Sk=None,D=None,pindex=None,pundex=None,kindex=None,rho=None,q=1E-42,alpha=1,perception=(1,0),smoothness=True,var=10,force=False,bare=True,**kwargs):
""" """
Infers the power spectrum. Infers the power spectrum.
...@@ -338,6 +338,9 @@ def infer_power(m,domain=None,Sk=None,D=None,pindex=None,pundex=None,kindex=None ...@@ -338,6 +338,9 @@ def infer_power(m,domain=None,Sk=None,D=None,pindex=None,pundex=None,kindex=None
(default: True). (default: True).
var : {scalar, list, array}, *optional* var : {scalar, list, array}, *optional*
Variance of the assumed spectral smoothness prior (default: 10). Variance of the assumed spectral smoothness prior (default: 10).
force : bool, *optional*, *experimental*
Indicates whether smoothness is to be enforces or not
(default: False).
bare : bool, *optional* bare : bool, *optional*
Indicates whether the power spectrum entries returned are "bare" Indicates whether the power spectrum entries returned are "bare"
or not (mandatory for the correct incorporation of volume weights) or not (mandatory for the correct incorporation of volume weights)
...@@ -505,6 +508,7 @@ def infer_power(m,domain=None,Sk=None,D=None,pindex=None,pundex=None,kindex=None ...@@ -505,6 +508,7 @@ def infer_power(m,domain=None,Sk=None,D=None,pindex=None,pundex=None,kindex=None
numerator = weight_power(domain,numerator,power=-1,pindex=pindex,pundex=pundex) ## bare(!) numerator = weight_power(domain,numerator,power=-1,pindex=pindex,pundex=pundex) ## bare(!)
## smoothness prior ## smoothness prior
permill = 0
divergence = 1 divergence = 1
while(divergence): while(divergence):
pk = numerator/denominator1 ## bare(!) pk = numerator/denominator1 ## bare(!)
...@@ -524,7 +528,7 @@ def infer_power(m,domain=None,Sk=None,D=None,pindex=None,pundex=None,kindex=None ...@@ -524,7 +528,7 @@ def infer_power(m,domain=None,Sk=None,D=None,pindex=None,pundex=None,kindex=None
absdelta = np.abs(delta).max() absdelta = np.abs(delta).max()
tk += min(1,0.1/absdelta)*delta # adaptive step width tk += min(1,0.1/absdelta)*delta # adaptive step width
pk *= np.exp(min(1,0.1/absdelta)*delta) # adaptive step width pk *= np.exp(min(1,0.1/absdelta)*delta) # adaptive step width
var_ /= 1.1 # lowering the variance when converged var_ /= 1.1+permill # lowering the variance when converged
if(var_<var): if(var_<var):
if(breakinfo): # making sure there's one iteration with the correct variance if(breakinfo): # making sure there's one iteration with the correct variance
break break
...@@ -538,6 +542,14 @@ def infer_power(m,domain=None,Sk=None,D=None,pindex=None,pundex=None,kindex=None ...@@ -538,6 +542,14 @@ def infer_power(m,domain=None,Sk=None,D=None,pindex=None,pundex=None,kindex=None
break break
else: else:
divergence += 1 divergence += 1
if(force):
permill = 0.001
elif(force)and(var_/var_OLD>1.001):
permill = 0
pot = int(np.log10(var_))
var = int(1+var_*10**-pot)*10**pot
about.warnings.cprint("WARNING: smoothness variance increased ( var = "+str(var)+" ).")
break
else: else:
var_OLD = var_ var_OLD = var_
if(breakinfo): if(breakinfo):
......
...@@ -65,7 +65,7 @@ def draw_vector_nd(axes,dgrid,ps,symtype=0,fourier=False,zerocentered=False,kpac ...@@ -65,7 +65,7 @@ def draw_vector_nd(axes,dgrid,ps,symtype=0,fourier=False,zerocentered=False,kpac
""" """
if(kpack is None): if(kpack is None):
kdict = np.fft.fftshift(nkdict(axes,dgrid,fourier)) kdict = np.fft.fftshift(nkdict_fast(axes,dgrid,fourier))
klength = nklength(kdict) klength = nklength(kdict)
else: else:
kdict = kpack[1][np.fft.ifftshift(kpack[0],axes=shiftaxes(zerocentered,st_to_zero_mode=False))] kdict = kpack[1][np.fft.ifftshift(kpack[0],axes=shiftaxes(zerocentered,st_to_zero_mode=False))]
...@@ -164,7 +164,7 @@ def draw_vector_nd(axes,dgrid,ps,symtype=0,fourier=False,zerocentered=False,kpac ...@@ -164,7 +164,7 @@ def draw_vector_nd(axes,dgrid,ps,symtype=0,fourier=False,zerocentered=False,kpac
# foufield = field # foufield = field
# fieldabs = np.abs(foufield)**2 # fieldabs = np.abs(foufield)**2
# #
# kdict = nkdict(axes,dgrid,fourier) # kdict = nkdict_fast(axes,dgrid,fourier)
# klength = nklength(kdict) # klength = nklength(kdict)
# #
# ## power spectrum # ## power spectrum
...@@ -228,7 +228,7 @@ def calc_ps_fast(field,axes,dgrid,zerocentered=False,fourier=False,pindex=None,k ...@@ -228,7 +228,7 @@ def calc_ps_fast(field,axes,dgrid,zerocentered=False,fourier=False,pindex=None,k
if(rho is None): if(rho is None):
if(pindex is None): if(pindex is None):
## kdict ## kdict
kdict = nkdict(axes,dgrid,fourier) kdict = nkdict_fast(axes,dgrid,fourier)
## klength ## klength
if(kindex is None): if(kindex is None):
klength = nklength(kdict) klength = nklength(kdict)
...@@ -253,7 +253,7 @@ def calc_ps_fast(field,axes,dgrid,zerocentered=False,fourier=False,pindex=None,k ...@@ -253,7 +253,7 @@ def calc_ps_fast(field,axes,dgrid,zerocentered=False,fourier=False,pindex=None,k
rho[pindex[ii]] += 1 rho[pindex[ii]] += 1
elif(pindex is None): elif(pindex is None):
## kdict ## kdict
kdict = nkdict(axes,dgrid,fourier) kdict = nkdict_fast(axes,dgrid,fourier)
## klength ## klength
if(kindex is None): if(kindex is None):
klength = nklength(kdict) klength = nklength(kdict)
...@@ -317,9 +317,9 @@ def get_power_index(axes,dgrid,zerocentered,irred=False,fourier=True): ...@@ -317,9 +317,9 @@ def get_power_index(axes,dgrid,zerocentered,irred=False,fourier=True):
## kdict, klength ## kdict, klength
if(np.any(zerocentered==False)): if(np.any(zerocentered==False)):
kdict = np.fft.fftshift(nkdict(axes,dgrid,fourier),axes=shiftaxes(zerocentered,st_to_zero_mode=True)) kdict = np.fft.fftshift(nkdict_fast(axes,dgrid,fourier),axes=shiftaxes(zerocentered,st_to_zero_mode=True))
else: else:
kdict = nkdict(axes,dgrid,fourier) kdict = nkdict_fast(axes,dgrid,fourier)
klength = nklength(kdict) klength = nklength(kdict)
## output ## output
if(irred): if(irred):
...@@ -372,9 +372,9 @@ def get_power_indices(axes,dgrid,zerocentered,fourier=True): ...@@ -372,9 +372,9 @@ def get_power_indices(axes,dgrid,zerocentered,fourier=True):
## kdict, klength ## kdict, klength
if(np.any(zerocentered==False)): if(np.any(zerocentered==False)):
kdict = np.fft.fftshift(nkdict(axes,dgrid,fourier),axes=shiftaxes(zerocentered,st_to_zero_mode=True)) kdict = np.fft.fftshift(nkdict_fast(axes,dgrid,fourier),axes=shiftaxes(zerocentered,st_to_zero_mode=True))
else: else:
kdict = nkdict(axes,dgrid,fourier) kdict = nkdict_fast(axes,dgrid,fourier)
klength = nklength(kdict) klength = nklength(kdict)
## output ## output
ind = np.empty(axes,dtype=np.int) ind = np.empty(axes,dtype=np.int)
...@@ -587,13 +587,11 @@ def shiftaxes(zerocentered,st_to_zero_mode=False): ...@@ -587,13 +587,11 @@ def shiftaxes(zerocentered,st_to_zero_mode=False):
def nkdict(axes,dgrid,fourier=True): def nkdict(axes,dgrid,fourier=True):
""" """
Calculates an n-dimensional array with its entries being the lengths of Calculates an n-dimensional array with its entries being the lengths of
the k-vectors from the zero point of the Fourier grid. the k-vectors from the zero point of the Fourier grid.
""" """
if(fourier): if(fourier):
dk = dgrid dk = dgrid
else: else:
...@@ -605,6 +603,25 @@ def nkdict(axes,dgrid,fourier=True): ...@@ -605,6 +603,25 @@ def nkdict(axes,dgrid,fourier=True):
return kdict return kdict
def nkdict_fast(axes,dgrid,fourier=True):
"""
Calculates an n-dimensional array with its entries being the lengths of
the k-vectors from the zero point of the Fourier grid.
"""
if(fourier):
dk = dgrid
else:
dk = np.array([1/dgrid[i]/axes[i] for i in range(len(axes))])
temp_vecs = np.array(np.where(np.ones(axes)),dtype='float').reshape(np.append(len(axes),axes))
temp_vecs = np.rollaxis(temp_vecs,0,len(temp_vecs.shape))
temp_vecs -= axes//2
temp_vecs *= dk
temp_vecs *= temp_vecs
return np.sqrt(np.sum((temp_vecs),axis=-1))
def nklength(kdict): def nklength(kdict):
return np.sort(list(set(kdict.flatten()))) return np.sort(list(set(kdict.flatten())))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment