Commit a041bd3f authored by Ultima's avatar Ultima
Browse files

Implemented a configuration class and a dependency injector.

Resorted the packages and the import structure.
parent 85ba4336
...@@ -20,14 +20,17 @@ operators/* ...@@ -20,14 +20,17 @@ operators/*
!operators/nifty_probing.py !operators/nifty_probing.py
!operators/nifty_probing_old.py !operators/nifty_probing_old.py
dummys/*
!dummys/__init__.py
!dummys/gfft_dummy.py
!dummys/MPI_dummy.py
rg/* rg/*
!rg/__init__.py !rg/__init__.py
!rg/fft_rg.py !rg/nifty_fft.py
!rg/nifty_rg.py !rg/nifty_rg.py
!rg/nifty_power_conversion_rg.py !rg/nifty_power_conversion_rg.py
!rg/gfft_rg.py
!rg/powerspectrum.py
lm/* lm/*
!lm/__init__.py !lm/__init__.py
......
...@@ -22,15 +22,22 @@ ...@@ -22,15 +22,22 @@
from __future__ import division from __future__ import division
import matplotlib as mpl import matplotlib as mpl
#mpl.use('Agg') mpl.use('Agg')
import dummys
from keepers import about,\
global_dependency_injector,\
global_configuration
from nifty_about import about
from nifty_cmaps import ncmap from nifty_cmaps import ncmap
from nifty_core import space,\ from nifty_core import space,\
point_space,\ point_space,\
nested_space,\ nested_space,\
field field
from nifty_mpi_data import distributed_data_object, d2o_librarian from nifty_mpi_data import distributed_data_object, d2o_librarian
from nifty_power import * from nifty_power import *
from nifty_random import random from nifty_random import random
...@@ -64,13 +71,13 @@ try: ...@@ -64,13 +71,13 @@ try:
from nifty_paradict import gl_space_paradict from nifty_paradict import gl_space_paradict
except(ImportError): except(ImportError):
pass pass
try: try:
from lm import hp_space from lm import hp_space
from nifty_paradict import hp_space_paradict from nifty_paradict import hp_space_paradict
except(ImportError): except(ImportError):
pass pass
except(ImportError): except(ImportError):
pass pass
......
...@@ -55,7 +55,7 @@ class problem(object): ...@@ -55,7 +55,7 @@ class problem(object):
self.z = x_space self.z = x_space
## set conjugate space ## set conjugate space
self.k = self.z.get_codomain() self.k = self.z.get_codomain()
self.k.set_power_indices(**kwargs) #self.k.set_power_indices(**kwargs)
## set some power spectrum ## set some power spectrum
self.power = (lambda k: 42 / (k + 1) ** 3) self.power = (lambda k: 42 / (k + 1) ** 3)
...@@ -156,17 +156,24 @@ class problem(object): ...@@ -156,17 +156,24 @@ class problem(object):
while(iterating): while(iterating):
## reconstruct map ## reconstruct map
self.m = self.D(self.j, W=self.S, tol=1E-3, note=False) self.m = self.D(self.j, W=self.S, tol=1E-3, note=True)
if(self.m is None): if(self.m is None):
break break
print 'Reconstructed m'
## reconstruct power spectrum ## reconstruct power spectrum
tr_B1 = self.Sk.pseudo_tr(self.m) ## == Sk(m).pseudo_dot(m) tr_B1 = self.Sk.pseudo_tr(self.m) ## == Sk(m).pseudo_dot(m)
print 'Calculated trace B1'
print ('tr_b1', tr_B1)
tr_B2 = self.Sk.pseudo_tr(self.D, loop=True) tr_B2 = self.Sk.pseudo_tr(self.D, loop=True)
print 'Calculated trace B2'
numerator = 2 * q + tr_B1 + abs(delta) * tr_B2 ## non-bare(!) print ('tr_B2', tr_B2)
numerator = 2 * q + tr_B1 + tr_B2 * abs(delta) ## non-bare(!)
power = numerator / denominator power = numerator / denominator
print ('numerator', numerator)
print ('denominator', denominator)
print ('power', power)
print 'Calculated power'
power = np.clip(power, 0.1, np.max(power))
## check convergence ## check convergence
dtau = log(power / self.S.get_power(), base=self.S.get_power()) dtau = log(power / self.S.get_power(), base=self.S.get_power())
iterating = (np.max(np.abs(dtau)) > 2E-2) iterating = (np.max(np.abs(dtau)) > 2E-2)
...@@ -200,36 +207,39 @@ class problem(object): ...@@ -200,36 +207,39 @@ class problem(object):
##============================================================================= ##=============================================================================
##----------------------------------------------------------------------------- ##-----------------------------------------------------------------------------
#
if(__name__=="__main__"): if(__name__=="__main__"):
# pl.close("all") x = rg_space((128,))
p = problem(x, log = False)
## define signal space about.warnings.off()
x_space = rg_space(128) ## pl.close("all")
#
## setup problem # ## define signal space
p = problem(x_space, log=True) # x_space = rg_space(128)
## solve problem given some power spectrum #
p.solve() # ## setup problem
## solve problem # p = problem(x_space, log=True)
p.solve_critical() # ## solve problem given some power spectrum
# p.solve()
p.plot() # ## solve problem
# p.solve_critical()
## retrieve objects #
k_space = p.k # p.plot()
power = p.power #
S = p.S # ## retrieve objects
Sk = p.Sk # k_space = p.k
s = p.s # power = p.power
R = p.R # S = p.S
d_space = p.R.target # Sk = p.Sk
N = p.N # s = p.s
Nj = p.Nj # R = p.R
d = p.d # d_space = p.R.target
j = p.j # N = p.N
D = p.D # Nj = p.Nj
m = p.m # d = p.d
# j = p.j
# D = p.D
# m = p.m
##----------------------------------------------------------------------------- ##-----------------------------------------------------------------------------
...@@ -33,16 +33,16 @@ ...@@ -33,16 +33,16 @@
""" """
from __future__ import division from __future__ import division
from nifty import * # version 0.8.0 from nifty import * # version 0.8.0
about.warnings.on() about.warnings.off()
# some signal space; e.g., a two-dimensional regular grid # some signal space; e.g., a two-dimensional regular grid
x_space = rg_space([1280, 1280], datamodel = 'd2o') # define signal space x_space = rg_space([1280, 1280]) # define signal space
#x_space = rg_space(512) #x_space = rg_space(512)
#x_space = hp_space(32) #x_space = hp_space(32)
#x_space = gl_space(96) #x_space = gl_space(96)
k_space = x_space.get_codomain() # get conjugate space k_space = x_space.get_codomain() # get conjugate space
y_space = point_space(1280*1280, datamodel='d2o') y_space = point_space(1280*1280)
# some power spectrum # some power spectrum
power = (lambda k: 42 / (k + 1) ** 3) power = (lambda k: 42 / (k + 1) ** 3)
...@@ -58,7 +58,7 @@ N = diagonal_operator(d_space, diag=s.var(), bare=True) # define noise ...@@ -58,7 +58,7 @@ N = diagonal_operator(d_space, diag=s.var(), bare=True) # define noise
n = N.get_random_field(domain=d_space) # generate noise n = N.get_random_field(domain=d_space) # generate noise
d = R(s) + n # compute data d = R(s) #+ n # compute data
j = R.adjoint_times(N.inverse_times(d)) # define information source j = R.adjoint_times(N.inverse_times(d)) # define information source
D = propagator_operator(S=S, N=N, R=R) # define information propagator D = propagator_operator(S=S, N=N, R=R) # define information propagator
......
...@@ -4,12 +4,12 @@ import numpy as np ...@@ -4,12 +4,12 @@ import numpy as np
def MIN(): def MIN():
return np.min return np.min
def MAX(): def MAX():
return np.max return np.max
def SUM(): def SUM():
return np.sum return np.sum
...@@ -17,41 +17,41 @@ class _COMM_WORLD(): ...@@ -17,41 +17,41 @@ class _COMM_WORLD():
def __init__(self): def __init__(self):
self.rank = 0 self.rank = 0
self.size = 1 self.size = 1
def Get_rank(self): def Get_rank(self):
return self.rank return self.rank
def Get_size(self): def Get_size(self):
return self.size return self.size
def _scattergather_helper(self, sendbuf, recvbuf=None, **kwargs): def _scattergather_helper(self, sendbuf, recvbuf=None, **kwargs):
sendbuf = self._unwrapper(sendbuf) sendbuf = self._unwrapper(sendbuf)
recvbuf = self._unwrapper(recvbuf) recvbuf = self._unwrapper(recvbuf)
if recvbuf != None: if recvbuf != None:
recvbuf[:] = sendbuf recvbuf[:] = sendbuf
return recvbuf return recvbuf
else: else:
recvbuf = np.copy(sendbuf) recvbuf = np.copy(sendbuf)
return recvbuf return recvbuf
def bcast(self, sendbuf, *args, **kwargs): def bcast(self, sendbuf, *args, **kwargs):
return sendbuf return sendbuf
def Bcast(self, sendbuf, *args, **kwargs): def Bcast(self, sendbuf, *args, **kwargs):
return sendbuf return sendbuf
def scatter(self, sendbuf, *args, **kwargs): def scatter(self, sendbuf, *args, **kwargs):
return sendbuf[0] return sendbuf[0]
def Scatter(self, *args, **kwargs): def Scatter(self, *args, **kwargs):
return self._scattergather_helper(*args, **kwargs) return self._scattergather_helper(*args, **kwargs)
def Scatterv(self, *args, **kwargs): def Scatterv(self, *args, **kwargs):
return self._scattergather_helper(*args, **kwargs) return self._scattergather_helper(*args, **kwargs)
def gather(self, sendbuf, *args, **kwargs): def gather(self, sendbuf, *args, **kwargs):
return [sendbuf,] return [sendbuf,]
def Gather(self, *args, **kwargs): def Gather(self, *args, **kwargs):
return self._scattergather_helper(*args, **kwargs) return self._scattergather_helper(*args, **kwargs)
...@@ -60,30 +60,30 @@ class _COMM_WORLD(): ...@@ -60,30 +60,30 @@ class _COMM_WORLD():
def allgather(self, sendbuf, *args, **kwargs): def allgather(self, sendbuf, *args, **kwargs):
return [sendbuf,] return [sendbuf,]
def Allgather(self, *args, **kwargs): def Allgather(self, *args, **kwargs):
return self._scattergather_helper(*args, **kwargs) return self._scattergather_helper(*args, **kwargs)
def Allgatherv(self, *args, **kwargs): def Allgatherv(self, *args, **kwargs):
return self._scattergather_helper(*args, **kwargs) return self._scattergather_helper(*args, **kwargs)
def Allreduce(self, sendbuf, recvbuf, op, **kwargs): def Allreduce(self, sendbuf, recvbuf, op, **kwargs):
recvbuf[:] = op(sendbuf) recvbuf[:] = op(sendbuf)
return recvbuf return recvbuf
def allreduce(self, sendbuf, recvbuf, op, **kwargs): def allreduce(self, sendbuf, recvbuf, op, **kwargs):
recvbuf[:] = op(sendbuf) recvbuf[:] = op(sendbuf)
return recvbuf return recvbuf
def sendrecv(self, sendobj, **kwargs): def sendrecv(self, sendobj, **kwargs):
return sendobj return sendobj
def _unwrapper(self, x): def _unwrapper(self, x):
if isinstance(x, list): if isinstance(x, list):
return x[0] return x[0]
else: else:
return x return x
def Barrier(self): def Barrier(self):
pass pass
...@@ -91,7 +91,7 @@ class _datatype(): ...@@ -91,7 +91,7 @@ class _datatype():
def __init__(self, name): def __init__(self, name):
self.name = str(name) self.name = str(name)
BYTE = _datatype('MPI_BYTE') BYTE = _datatype('MPI_BYTE')
SHORT = _datatype('MPI_SHORT') SHORT = _datatype('MPI_SHORT')
UNSIGNED_SHORT = _datatype("MPI_UNSIGNED_SHORT") UNSIGNED_SHORT = _datatype("MPI_UNSIGNED_SHORT")
UNSIGNED_INT = _datatype("MPI_UNSIGNED_INT") UNSIGNED_INT = _datatype("MPI_UNSIGNED_INT")
...@@ -106,17 +106,16 @@ LONG_DOUBLE = _datatype("MPI_LONG_DOUBLE") ...@@ -106,17 +106,16 @@ LONG_DOUBLE = _datatype("MPI_LONG_DOUBLE")
COMPLEX = _datatype("MPI_COMPLEX") COMPLEX = _datatype("MPI_COMPLEX")
DOUBLE_COMPLEX = _datatype("MPI_DOUBLE_COMPLEX") DOUBLE_COMPLEX = _datatype("MPI_DOUBLE_COMPLEX")
COMM_WORLD = _COMM_WORLD() COMM_WORLD = _COMM_WORLD()
\ No newline at end of file
# -*- coding: utf-8 -*-
import gfft_dummy
import MPI_dummy
\ No newline at end of file
# -*- coding: utf-8 -*-
from nifty_about import *
from nifty_default_config import global_dependency_injector,\
global_configuration
\ No newline at end of file
...@@ -326,6 +326,8 @@ class notification(switch): ...@@ -326,6 +326,8 @@ class notification(switch):
##----------------------------------------------------------------------------- ##-----------------------------------------------------------------------------
class _about(object): ## nifty support class for global settings class _about(object): ## nifty support class for global settings
""" """
NIFTY support class for global settings. NIFTY support class for global settings.
...@@ -389,9 +391,9 @@ class _about(object): ## nifty support class for global settings ...@@ -389,9 +391,9 @@ class _about(object): ## nifty support class for global settings
self._version = str(__version__) self._version = str(__version__)
## switches and notifications ## switches and notifications
self._errors = notification(default=True, self._errors = notification(default=True,
ccode=notification._code) ccode=notification._code)
self.warnings = notification(default=True, self.warnings = notification(default=True,
ccode=notification._code) ccode=notification._code)
self.infos = notification(default=False, self.infos = notification(default=False,
ccode=notification._code) ccode=notification._code)
...@@ -486,7 +488,7 @@ class _about(object): ## nifty support class for global settings ...@@ -486,7 +488,7 @@ class _about(object): ## nifty support class for global settings
## set global instance ## set global instance
about = _about() about = _about()
about.load_config(force=False) #about.load_config(force=False)
about.infos.cprint("INFO: "+about.__repr__()) #about.infos.cprint("INFO: "+about.__repr__())
# -*- coding: utf-8 -*-
import ConfigParser
class variable(object):
def __init__(self, name, default_value_list, checker, genus='str'):
self.name = str(name)
if genus in ['str', 'int', 'float', 'boolean']:
self.genus = str(genus)
if not isinstance(default_value_list, list):
raise ValueError("The default_value_list argument must be a list!")
elif len(default_value_list) == 0:
default_value_list = [None]
self.default_value_list = default_value_list
if callable(checker):
self.checker = checker
else:
raise ValueError("The checker must be callable!")
self.set_value(None)
def set_value(self, value=None):
if value is None:
work_list = self.default_value_list
else:
work_list = [value]
success = False
for item in work_list:
if self.checker(item):
valid_item = item
success = True
break
if not success:
raise ValueError("No valid value supplied!" + str(work_list))
else:
self.value = valid_item
return self.value
def __call__(self):
return self.get_value()
def get_value(self):
return self.value
def get_name(self):
return self.name
def __repr__(self):
return "<nifty variable '" + str(self.name) + "': " + \
str(self.get_value()) + ">"
class configuration(object):
"""
configuration
init with file path to configuration file
get command -> dictionary like
set command -> parse input for sanity
reset -> reset to defaults
load from file
save to file
"""
def __init__(self, variables=[], path=None, path_section='DEFAULT'):
self.variable_dict = {}
map(self.register, variables)
self.path = None
self.set_path(path=path, path_section=path_section)
try:
self.load()
except ValueError:
pass
def __getitem__(self, key):
return self.get_variable(key)
def __setitem__(self, key, value):
return self.set_variable(key, value)
def register(self, variable):
self.variable_dict[variable.get_name()] = variable
def get_variable(self, name):
try:
return self.variable_dict[name].get_value()
except KeyError:
raise KeyError("The requested variable is not registered!")
def set_variable(self, name, value):
try:
return self.variable_dict[name].set_value(value)
except KeyError:
raise KeyError("The requested variable is not registered!")
def set_path(self, path=None, path_section=None):
if path is not None:
self.path = str(path)
if path_section is not None:
self.path_section = str(path_section)
def reset(self):
for key, item in self.variable_dict.items():
item.set_value(None)
def save(self, path=None, path_section=None):
if path is None:
if self.path is None:
raise ValueError("No init- or keyword-path available.")
else:
path = self.path
else:
path = path
if path_section is None:
path_section = self.path_section