Commit a041bd3f authored by Ultima's avatar Ultima
Browse files

Implemented a configuration class and a dependency injector.

Resorted the packages and the import structure.
parent 85ba4336
......@@ -20,14 +20,17 @@ operators/*
!operators/nifty_probing.py
!operators/nifty_probing_old.py
dummys/*
!dummys/__init__.py
!dummys/gfft_dummy.py
!dummys/MPI_dummy.py
rg/*
!rg/__init__.py
!rg/fft_rg.py
!rg/nifty_fft.py
!rg/nifty_rg.py
!rg/nifty_power_conversion_rg.py
!rg/gfft_rg.py
!rg/powerspectrum.py
lm/*
!lm/__init__.py
......
......@@ -22,9 +22,16 @@
from __future__ import division
import matplotlib as mpl
#mpl.use('Agg')
mpl.use('Agg')
import dummys
from keepers import about,\
global_dependency_injector,\
global_configuration
from nifty_about import about
from nifty_cmaps import ncmap
from nifty_core import space,\
point_space,\
......
......@@ -55,7 +55,7 @@ class problem(object):
self.z = x_space
## set conjugate space
self.k = self.z.get_codomain()
self.k.set_power_indices(**kwargs)
#self.k.set_power_indices(**kwargs)
## set some power spectrum
self.power = (lambda k: 42 / (k + 1) ** 3)
......@@ -156,17 +156,24 @@ class problem(object):
while(iterating):
## reconstruct map
self.m = self.D(self.j, W=self.S, tol=1E-3, note=False)
self.m = self.D(self.j, W=self.S, tol=1E-3, note=True)
if(self.m is None):
break
print 'Reconstructed m'
## reconstruct power spectrum
tr_B1 = self.Sk.pseudo_tr(self.m) ## == Sk(m).pseudo_dot(m)
print 'Calculated trace B1'
print ('tr_b1', tr_B1)
tr_B2 = self.Sk.pseudo_tr(self.D, loop=True)
numerator = 2 * q + tr_B1 + abs(delta) * tr_B2 ## non-bare(!)
print 'Calculated trace B2'
print ('tr_B2', tr_B2)
numerator = 2 * q + tr_B1 + tr_B2 * abs(delta) ## non-bare(!)
power = numerator / denominator
print ('numerator', numerator)
print ('denominator', denominator)
print ('power', power)
print 'Calculated power'
power = np.clip(power, 0.1, np.max(power))
## check convergence
dtau = log(power / self.S.get_power(), base=self.S.get_power())
iterating = (np.max(np.abs(dtau)) > 2E-2)
......@@ -200,36 +207,39 @@ class problem(object):
##=============================================================================
##-----------------------------------------------------------------------------
#
if(__name__=="__main__"):
# pl.close("all")
## define signal space
x_space = rg_space(128)
## setup problem
p = problem(x_space, log=True)
## solve problem given some power spectrum
p.solve()
## solve problem
p.solve_critical()
p.plot()
## retrieve objects
k_space = p.k
power = p.power
S = p.S
Sk = p.Sk
s = p.s
R = p.R
d_space = p.R.target
N = p.N
Nj = p.Nj
d = p.d
j = p.j
D = p.D
m = p.m
x = rg_space((128,))
p = problem(x, log = False)
about.warnings.off()
## pl.close("all")
#
# ## define signal space
# x_space = rg_space(128)
#
# ## setup problem
# p = problem(x_space, log=True)
# ## solve problem given some power spectrum
# p.solve()
# ## solve problem
# p.solve_critical()
#
# p.plot()
#
# ## retrieve objects
# k_space = p.k
# power = p.power
# S = p.S
# Sk = p.Sk
# s = p.s
# R = p.R
# d_space = p.R.target
# N = p.N
# Nj = p.Nj
# d = p.d
# j = p.j
# D = p.D
# m = p.m
##-----------------------------------------------------------------------------
......@@ -33,16 +33,16 @@
"""
from __future__ import division
from nifty import * # version 0.8.0
about.warnings.on()
about.warnings.off()
# some signal space; e.g., a two-dimensional regular grid
x_space = rg_space([1280, 1280], datamodel = 'd2o') # define signal space
x_space = rg_space([1280, 1280]) # define signal space
#x_space = rg_space(512)
#x_space = hp_space(32)
#x_space = gl_space(96)
k_space = x_space.get_codomain() # get conjugate space
y_space = point_space(1280*1280, datamodel='d2o')
y_space = point_space(1280*1280)
# some power spectrum
power = (lambda k: 42 / (k + 1) ** 3)
......@@ -58,7 +58,7 @@ N = diagonal_operator(d_space, diag=s.var(), bare=True) # define noise
n = N.get_random_field(domain=d_space) # generate noise
d = R(s) + n # compute data
d = R(s) #+ n # compute data
j = R.adjoint_times(N.inverse_times(d)) # define information source
D = propagator_operator(S=S, N=N, R=R) # define information propagator
......
......@@ -119,4 +119,3 @@ COMM_WORLD = _COMM_WORLD()
\ No newline at end of file
# -*- coding: utf-8 -*-
import gfft_dummy
import MPI_dummy
\ No newline at end of file
# -*- coding: utf-8 -*-
from nifty_about import *
from nifty_default_config import global_dependency_injector,\
global_configuration
\ No newline at end of file
......@@ -326,6 +326,8 @@ class notification(switch):
##-----------------------------------------------------------------------------
class _about(object): ## nifty support class for global settings
"""
NIFTY support class for global settings.
......@@ -486,7 +488,7 @@ class _about(object): ## nifty support class for global settings
## set global instance
about = _about()
about.load_config(force=False)
about.infos.cprint("INFO: "+about.__repr__())
#about.load_config(force=False)
#about.infos.cprint("INFO: "+about.__repr__())
# -*- coding: utf-8 -*-
import ConfigParser
class variable(object):
def __init__(self, name, default_value_list, checker, genus='str'):
self.name = str(name)
if genus in ['str', 'int', 'float', 'boolean']:
self.genus = str(genus)
if not isinstance(default_value_list, list):
raise ValueError("The default_value_list argument must be a list!")
elif len(default_value_list) == 0:
default_value_list = [None]
self.default_value_list = default_value_list
if callable(checker):
self.checker = checker
else:
raise ValueError("The checker must be callable!")
self.set_value(None)
def set_value(self, value=None):
if value is None:
work_list = self.default_value_list
else:
work_list = [value]
success = False
for item in work_list:
if self.checker(item):
valid_item = item
success = True
break
if not success:
raise ValueError("No valid value supplied!" + str(work_list))
else:
self.value = valid_item
return self.value
def __call__(self):
return self.get_value()
def get_value(self):
return self.value
def get_name(self):
return self.name
def __repr__(self):
return "<nifty variable '" + str(self.name) + "': " + \
str(self.get_value()) + ">"
class configuration(object):
"""
configuration
init with file path to configuration file
get command -> dictionary like
set command -> parse input for sanity
reset -> reset to defaults
load from file
save to file
"""
def __init__(self, variables=[], path=None, path_section='DEFAULT'):
self.variable_dict = {}
map(self.register, variables)
self.path = None
self.set_path(path=path, path_section=path_section)
try:
self.load()
except ValueError:
pass
def __getitem__(self, key):
return self.get_variable(key)
def __setitem__(self, key, value):
return self.set_variable(key, value)
def register(self, variable):
self.variable_dict[variable.get_name()] = variable
def get_variable(self, name):
try:
return self.variable_dict[name].get_value()
except KeyError:
raise KeyError("The requested variable is not registered!")
def set_variable(self, name, value):
try:
return self.variable_dict[name].set_value(value)
except KeyError:
raise KeyError("The requested variable is not registered!")
def set_path(self, path=None, path_section=None):
if path is not None:
self.path = str(path)
if path_section is not None:
self.path_section = str(path_section)
def reset(self):
for key, item in self.variable_dict.items():
item.set_value(None)
def save(self, path=None, path_section=None):
if path is None:
if self.path is None:
raise ValueError("No init- or keyword-path available.")
else:
path = self.path
else:
path = path
if path_section is None:
path_section = self.path_section
config_parser = ConfigParser.ConfigParser()
try:
config_parser.add_section(path_section)
except ValueError:
pass
for item in self.variable_dict:
config_parser.set(path_section,
item,
str(self[item]))
config_file = open(path, 'wb')
config_parser.write(config_file)
def load(self, path=None, path_section=None):
if path is None:
if self.path is None:
raise ValueError("No init- or keyword-path available.")
else:
path = self.path
else:
path = path
if path_section is None:
path_section = self.path_section
config_parser = ConfigParser.ConfigParser()
config_parser.read(path)
for key, item in self.variable_dict.items():
if item.genus == 'str':
temp_value = config_parser.get(path_section, item.name)
elif item.genus == 'int':
temp_value = config_parser.getint(path_section, item.name)
elif item.genus == 'float':
temp_value = config_parser.getfloat(path_section, item.name)
elif item.genus == 'boolean':
temp_value = config_parser.getboolean(path_section, item.name)
else:
raise ValueError("Unknown variable genus.")
item.set_value(temp_value)
def __repr__(self):
return "<nifty configuration> \n" + self.variable_dict.__repr__()
# -*- coding: utf-8 -*-
import os
from nifty_dependency_injector import dependency_injector
from nifty_configuration import variable,\
configuration
global_dependency_injector = dependency_injector(
['h5py',
('mpi4py.MPI', 'MPI'),
('nifty.dummys.MPI_dummy', 'MPI_dummy'),
'pyfftw',
'gfft',
('nifty.dummys.gfft_dummy', 'gfft_dummy'),
'healpy',
'libsharp_wrapper_gl'])
variable_fft_module = variable('fft_module',
['pyfftw', 'gfft', 'gfft_fallback'],
lambda z: z in global_dependency_injector)
variable_lm2gl = variable('lm2gl',
[True, False],
lambda z: z is True or z is False,
'boolean')
variable_verbosity = variable('verbosity',
[1],
lambda z: z == abs(int(z)),
'int')
variable_mpi_module = variable('mpi_module',
['MPI', 'MPI_dummy'],
lambda z: z in global_dependency_injector)
variable_default_distribution_strategy = variable(
'default_distribution_strategy',
['fftw', 'equal', 'not'],
lambda z: (('pyfftw' in global_dependency_injector)
if (z == 'pyfftw') else True)
)
global_configuration = configuration(
[variable_fft_module,
variable_lm2gl,
variable_verbosity,
variable_mpi_module,
variable_default_distribution_strategy
],
path=os.path.expanduser('~') + "/.nifty/global_config")
# -*- coding: utf-8 -*-
import imp
import sys
class dependency_injector(object):
def __init__(self, modules=[]):
self.registry = {}
map(self.register, modules)
def get(self, x):
return self.registry.get(x)
def __getitem__(self, x):
return self.registry.__getitem__(x)
def __contains__(self, x):
return self.registry.__contains__(x)
def __iter__(self):
return self.registry.__iter__()
def __getattr__(self, x):
return self.registry.__getattribute__(x)
def register(self, module_name):
if isinstance(module_name, tuple):
module_name, key_name = (str(module_name[0]), str(module_name[1]))
else:
module_name = str(module_name)
key_name = module_name
try:
loaded_module = sys.modules[module_name]
self.registry[key_name] = loaded_module
except KeyError:
try:
fp, pathname, description = imp.find_module(module_name)
loaded_module = \
imp.load_module(module_name, fp, pathname, description)
self.registry[key_name] = loaded_module
except ImportError:
pass
finally:
# Since we may exit via an exception, close fp explicitly.
try:
fp.close()
except (UnboundLocalError, AttributeError):
pass
......@@ -20,7 +20,7 @@
## along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import division
from nifty.nifty_about import about
from nifty.keepers import about
from distutils.version import LooseVersion as lv
......
......@@ -38,7 +38,7 @@ import numpy as np
import pylab as pl
from matplotlib.colors import LogNorm as ln
from matplotlib.ticker import LogFormatter as lf
from nifty.nifty_about import about
from nifty.keepers import about
from nifty.nifty_core import pi, \
space, \
point_space, \
......@@ -179,6 +179,7 @@ class lm_space(point_space):
self.datamodel = datamodel
self.discrete = True
self.harmonic = True
self.vol = np.real(np.array([1],dtype=self.datatype))
@property
......@@ -532,6 +533,9 @@ class lm_space(point_space):
if(not isinstance(codomain,space)):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
if self.datamodel is not codomain.datamodel:
return False
if(self==codomain):
return True
......@@ -1063,6 +1067,7 @@ class gl_space(point_space):
self.datamodel = datamodel
self.discrete = False
self.harmonic = False
self.vol = gl.vol(self.paradict['nlat'],nlon=self.paradict['nlon']).astype(self.datatype)
......@@ -1313,6 +1318,9 @@ class gl_space(point_space):
if(not isinstance(codomain,space)):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
if self.datamodel is not codomain.datamodel:
return False
if(self==codomain):
return True
......@@ -1754,6 +1762,7 @@ class hp_space(point_space):
self.datamodel = datamodel
self.discrete = False
self.harmonic = False
self.vol = np.array([4*pi/(12*self.paradict['nside']**2)],dtype=self.datatype)
@property
......@@ -1987,6 +1996,9 @@ class hp_space(point_space):
if(not isinstance(codomain,space)):
raise TypeError(about._errors.cstring("ERROR: invalid input."))
if self.datamodel is not codomain.datamodel:
return False
if(self==codomain):
return True
......
......@@ -21,7 +21,7 @@
#from nifty import *
import numpy as np
from nifty.nifty_about import about
from nifty.keepers import about
from nifty.nifty_core import pi, field
from nifty.nifty_simple_math import sqrt, exp, log
......
This diff is collapsed.
......@@ -22,39 +22,49 @@
import numpy as np
from nifty_about import about
from weakref import WeakValueDictionary as weakdict
# initialize the 'FOUND-packages'-dictionary
FOUND = {}
try:
from mpi4py import MPI
FOUND['MPI'] = True
except(ImportError):
import mpi_dummy as MPI
FOUND['MPI'] = False
from keepers import about,\
global_configuration as gc,\
global_dependency_injector as gdi
try:
import pyfftw
FOUND['pyfftw'] = True
except(ImportError):
FOUND['pyfftw'] = False
MPI = gdi[gc['mpi_module']]
h5py = gdi.get('h5py')
pyfftw = gdi.get('pyfftw')
try:
import h5py
FOUND['h5py'] = True
FOUND['h5py_parallel'] = h5py.get_config().mpi
except(ImportError):
FOUND['h5py'] = False
FOUND['h5py_parallel'] = False
ALL_DISTRIBUTION_STRATEGIES = ['not', 'equal', 'fftw', 'freeform']
GLOBAL_DISTRIBUTION_STRATEGIES = ['not', 'equal', 'fftw']
LOCAL_DISTRIBUTION_STRATEGIES = ['freeform']
HDF5_DISTRIBUTION_STRATEGIES = ['equal', 'fftw']