Commit 6343088d authored by theos's avatar theos
Browse files

dependency_injector now is able to apply a check function to a loaded module before registering.

Added a check for pyfftw.FFTW_MPI
parent 86674115
Pipeline #1335 skipped
......@@ -6,16 +6,17 @@ from nifty_dependency_injector import dependency_injector
from nifty_configuration import variable,\
configuration
global_dependency_injector = dependency_injector(
['h5py',
('mpi4py.MPI', 'MPI'),
('nifty.dummys.MPI_dummy', 'MPI_dummy'),
'pyfftw',
'gfft',
('nifty.dummys.gfft_dummy', 'gfft_dummy'),
'healpy',
'libsharp_wrapper_gl'])
global_dependency_injector.register('pyfftw', lambda z: hasattr(z, 'FFTW_MPI'))
variable_fft_module = variable('fft_module',
['pyfftw', 'gfft', 'gfft_dummy'],
......@@ -57,7 +58,7 @@ variable_default_distribution_strategy = variable(
'default_distribution_strategy',
['fftw', 'equal', 'not'],
lambda z: (('pyfftw' in global_dependency_injector)
if (z == 'pyfftw') else True)
if (z == 'fftw') else True)
)
variable_d2o_init_checks = variable('d2o_init_checks',
......
......@@ -22,34 +22,29 @@ class dependency_injector(object):
def __getattr__(self, x):
return self.registry.__getattribute__(x)
def register(self, module_name):
def register(self, module_name, check=None):
if isinstance(module_name, tuple):
module_name, key_name = (str(module_name[0]), str(module_name[1]))
else:
module_name = str(module_name)
key_name = module_name
try:
loaded_module = sys.modules[module_name]
self.registry[key_name] = loaded_module
except KeyError:
try:
loaded_module = recursive_import(module_name)
# print module_name
# fp, pathname, description = imp.find_module(module_name)
# print pathname
# loaded_module = \
# imp.load_module(module_name, fp, pathname, description)
#
# print loaded_module
self.registry[key_name] = loaded_module
except ImportError:
pass
# finally:
# # Since we may exit via an exception, close fp explicitly.
# try:
# fp.close()
# except (UnboundLocalError, AttributeError):
# pass
if loaded_module is not None:
if check is not None:
check_passed = check(loaded_module)
else:
check_passed = True
if check_passed is True:
self.registry[key_name] = loaded_module
def unregister(self, module_name):
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment