Commit de0c9653 authored by Martin Reinecke's avatar Martin Reinecke
Browse files

remove dependency_injector

parent 019e8ff5
......@@ -24,8 +24,7 @@ from .version import __version__
from keepers import MPILogger
logger = MPILogger()
from .config import dependency_injector,\
nifty_configuration
from .config import nifty_configuration
logger.logger.setLevel(nifty_configuration['loglevel'])
......
......@@ -17,5 +17,4 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
from .nifty_config import dependency_injector,\
nifty_configuration
from .nifty_config import nifty_configuration
......@@ -23,28 +23,9 @@ import os
import numpy as np
import keepers
__all__ = ['dependency_injector', 'nifty_configuration']
# Setup the dependency injector
dependency_injector = keepers.DependencyInjector(
[('mpi4py.MPI', 'MPI'),
('pyfftw', 'fftw'),
'pyHealpix',
'plotly'])
def _fft_module_checker(z):
if z == 'fftw':
return ('fftw' in dependency_injector)
if z == 'numpy':
return True
return False
__all__ = ['nifty_configuration']
# Initialize the variables
variable_fft_module = keepers.Variable(
'fft_module',
['fftw', 'numpy'],
_fft_module_checker)
variable_harmonic_rg_base = keepers.Variable(
'harmonic_rg_base',
......@@ -60,8 +41,7 @@ variable_loglevel = keepers.Variable(
nifty_configuration = keepers.get_Configuration(
name='NIFTy',
variables=[variable_fft_module,
variable_harmonic_rg_base,
variables=[variable_harmonic_rg_base,
variable_loglevel],
file_name='NIFTy.conf',
search_paths=[os.path.expanduser('~') + "/.config/nifty/",
......
......@@ -61,14 +61,6 @@ class FFTOperator(LinearOperator):
For GLSpace, HPSpace, and LMSpace, a sensible (but not unique)
co-domain is chosen that should work satisfactorily in most situations,
but for full control, the user should explicitly specify a codomain.
module: String (optional)
Software module employed for carrying out the transform operations.
For RGSpace pairs this can be "scalar" or "mpi", where "scalar" is
always available (using pyfftw if available, else numpy.fft), and "mpi"
requires pyfftw and offers MPI parallelization.
For sphere-related domains, only "pyHealpix" is
available. If omitted, "fftw" is selected for RGSpaces if available,
else "numpy"; on the sphere the default is "pyHealpix".
domain_dtype: data type (optional)
Data type of the fields that go into "times" and come out of
"adjoint_times". Default is "numpy.complex".
......@@ -112,7 +104,7 @@ class FFTOperator(LinearOperator):
# ---Overwritten properties and methods---
def __init__(self, domain, target=None, module=None,
def __init__(self, domain, target=None,
domain_dtype=None, target_dtype=None, default_spaces=None):
super(FFTOperator, self).__init__(default_spaces)
......@@ -136,10 +128,10 @@ class FFTOperator(LinearOperator):
(self.target[0].__class__, self.domain[0].__class__)]
self._forward_transformation = TransformationCache.create(
forward_class, self.domain[0], self.target[0], module=module)
forward_class, self.domain[0], self.target[0])
self._backward_transformation = TransformationCache.create(
backward_class, self.target[0], self.domain[0], module=module)
backward_class, self.target[0], self.domain[0])
# Store the dtype information
self.domain_dtype = \
......
......@@ -18,30 +18,19 @@
import numpy as np
from ....config import dependency_injector as gdi
from .... import GLSpace, LMSpace
from .slicing_transformation import SlicingTransformation
from . import lm_transformation_helper
pyHealpix = gdi.get('pyHealpix')
import pyHealpix
class GLLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None):
if module is None:
module = 'pyHealpix'
if module != 'pyHealpix':
raise ValueError("Unsupported SHT module.")
if pyHealpix is None:
raise ImportError(
"The module pyHealpix is needed but not available.")
super(GLLMTransformation, self).__init__(domain, codomain, module)
def __init__(self, domain, codomain=None):
super(GLLMTransformation, self).__init__(domain, codomain)
# ---Mandatory properties and methods---
......
......@@ -18,31 +18,20 @@
import numpy as np
from ....config import dependency_injector as gdi
from .... import HPSpace, LMSpace
from .slicing_transformation import SlicingTransformation
from . import lm_transformation_helper
pyHealpix = gdi.get('pyHealpix')
import pyHealpix
class HPLMTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None):
if module is None:
module = 'pyHealpix'
if module != 'pyHealpix':
raise ValueError("Unsupported SHT module.")
if pyHealpix is None:
raise ImportError(
"The module pyHealpix is needed but not available")
super(HPLMTransformation, self).__init__(domain, codomain, module)
def __init__(self, domain, codomain=None):
super(HPLMTransformation, self).__init__(domain, codomain)
# ---Mandatory properties and methods---
......
......@@ -17,31 +17,20 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ....config import dependency_injector as gdi
from .... import GLSpace, LMSpace
from .slicing_transformation import SlicingTransformation
from . import lm_transformation_helper
pyHealpix = gdi.get('pyHealpix')
import pyHealpix
class LMGLTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None):
if module is None:
module = 'pyHealpix'
if module != 'pyHealpix':
raise ValueError("Unsupported SHT module.")
if pyHealpix is None:
raise ImportError(
"The module pyHealpix is needed but not available.")
super(LMGLTransformation, self).__init__(domain, codomain, module)
def __init__(self, domain, codomain=None):
super(LMGLTransformation, self).__init__(domain, codomain)
# ---Mandatory properties and methods---
......
......@@ -17,30 +17,19 @@
# and financially supported by the Studienstiftung des deutschen Volkes.
import numpy as np
from ....config import dependency_injector as gdi
from .... import HPSpace, LMSpace
from .slicing_transformation import SlicingTransformation
from . import lm_transformation_helper
pyHealpix = gdi.get('pyHealpix')
import pyHealpix
class LMHPTransformation(SlicingTransformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None):
if module is None:
module = 'pyHealpix'
if module != 'pyHealpix':
raise ValueError("Unsupported SHT module.")
if pyHealpix is None:
raise ImportError(
"The module pyHealpix is needed but not available.")
super(LMHPTransformation, self).__init__(domain, codomain, module)
def __init__(self, domain, codomain=None):
super(LMHPTransformation, self).__init__(domain, codomain)
# ---Mandatory properties and methods---
......
......@@ -21,14 +21,12 @@ from builtins import object
import warnings
import numpy as np
from ....config import dependency_injector as gdi
from ....config import nifty_configuration as gc
from .... import nifty_utilities as utilities
from keepers import Loggable
from functools import reduce
fftw = gdi.get('fftw')
import pyfftw
class Transform(Loggable, object):
......@@ -208,17 +206,10 @@ class SerialFFT(Transform):
The numpy fft pendant of a fft object.
"""
def __init__(self, domain, codomain, use_fftw):
def __init__(self, domain, codomain):
super(SerialFFT, self).__init__(domain, codomain)
if use_fftw and (fftw is None):
raise ImportError(
"The serial FFTW module is needed but not available.")
self._use_fftw = use_fftw
# Enable caching
if self._use_fftw:
fftw.interfaces.cache.enable()
pyfftw.interfaces.cache.enable()
def transform(self, val, axes, **kwargs):
"""
......@@ -274,18 +265,12 @@ class SerialFFT(Transform):
local_val = self._apply_mask(temp_val, mask, axes)
# perform the transformation
if self._use_fftw:
if self.codomain.harmonic:
result_val = fftw.interfaces.numpy_fft.fftn(
local_val, axes=axes)
else:
result_val = fftw.interfaces.numpy_fft.ifftn(
local_val, axes=axes)
if self.codomain.harmonic:
result_val = pyfftw.interfaces.numpy_fft.fftn(
local_val, axes=axes)
else:
if self.codomain.harmonic:
result_val = np.fft.fftn(local_val, axes=axes)
else:
result_val = np.fft.ifftn(local_val, axes=axes)
result_val = pyfftw.interfaces.numpy_fft.ifftn(
local_val, axes=axes)
# Apply domain centering mask
if reduce(lambda x, y: x + y, self.domain.zerocenter):
......
......@@ -27,20 +27,9 @@ class RGRGTransformation(Transformation):
# ---Overwritten properties and methods---
def __init__(self, domain, codomain=None, module=None):
super(RGRGTransformation, self).__init__(domain, codomain, module)
if module is None:
module = nifty_configuration['fft_module']
if module == 'fftw':
self._transform = SerialFFT(self.domain, self.codomain,
use_fftw=True)
elif module == 'numpy':
self._transform = SerialFFT(self.domain, self.codomain,
use_fftw=False)
else:
raise ValueError('Unsupported FFT module:' + module)
def __init__(self, domain, codomain=None):
super(RGRGTransformation, self).__init__(domain, codomain)
self._transform = SerialFFT(self.domain, self.codomain)
self.harmonic_base = nifty_configuration['harmonic_rg_base']
......
......@@ -28,7 +28,7 @@ class Transformation(with_metaclass(abc.ABCMeta, type('NewBase', (Loggable, obje
method for all transforms.
"""
def __init__(self, domain, codomain, module=None):
def __init__(self, domain, codomain):
if codomain is None:
self.domain = domain
self.codomain = self.get_codomain(domain)
......
......@@ -22,10 +22,10 @@ class _TransformationCache(object):
def __init__(self):
self.cache = {}
def create(self, transformation_class, domain, codomain, module):
key = (domain, codomain, module)
def create(self, transformation_class, domain, codomain):
key = (domain, codomain)
if key not in self.cache:
self.cache[key] = transformation_class(domain, codomain, module)
self.cache[key] = transformation_class(domain, codomain)
return self.cache[key]
......
......@@ -3,18 +3,15 @@
from builtins import map
from builtins import str
import numpy as np
from ... import dependency_injector as gdi
from .figure_base import FigureBase
from .figure_3D import Figure3D
plotly = gdi.get('plotly')
import plotly
# TODO: add nice height and width defaults for multifigure
class MultiFigure(FigureBase):
def __init__(self, subfigures, title=None, width=None, height=None):
if 'plotly' not in gdi:
raise ImportError("The module plotly is needed but not available.")
super(MultiFigure, self).__init__(title, width, height)
if subfigures is not None:
self.subfigures = np.asarray(subfigures, dtype=np.object)
......
# -*- coding: utf-8 -*-
from .... import dependency_injector as gdi
from .heatmap import Heatmap
import numpy as np
......@@ -8,16 +7,13 @@ from ...descriptors import Axis
from .mollweide_helper import mollweide_helper
pyHealpix = gdi.get('pyHealpix')
import pyHealpix
class GLMollweide(Heatmap):
def __init__(self, data, xsize=800, color_map=None,
webgl=False, smoothing=False, zmin=None, zmax=None):
# smoothing 'best', 'fast', False
if pyHealpix is None:
raise ImportError(
"The module pyHealpix is needed but not available.")
self.xsize = xsize
super(GLMollweide, self).__init__(data, color_map, webgl, smoothing,
......
# -*- coding: utf-8 -*-
from __future__ import division
from .... import dependency_injector as gdi
from .heatmap import Heatmap
import numpy as np
......@@ -9,15 +8,12 @@ from ...descriptors import Axis
from .mollweide_helper import mollweide_helper
pyHealpix = gdi.get('pyHealpix')
import pyHealpix
class HPMollweide(Heatmap):
def __init__(self, data, xsize=800, color_map=None, webgl=False,
smoothing=False, zmin=None, zmax=None): # smoothing 'best', 'fast', False
if pyHealpix is None:
raise ImportError(
"The module pyHealpix is needed but not available.")
self.xsize = xsize
super(HPMollweide, self).__init__(data, color_map, webgl, smoothing,
zmin, zmax)
......
......@@ -13,8 +13,6 @@ import d2o
from keepers import Loggable
from ...config import dependency_injector as gdi
from ...spaces.space import Space
from ...field import Field
from ... import nifty_utilities as utilities
......@@ -22,14 +20,11 @@ from ... import nifty_utilities as utilities
from ..figures import MultiFigure
from future.utils import with_metaclass
plotly = gdi.get('plotly')
import plotly
if plotly is not None and 'IPython' in sys.modules:
plotly.offline.init_notebook_mode()
rank = d2o.config.dependency_injector[
d2o.configuration['mpi_module']].COMM_WORLD.rank
class PlotterBase(with_metaclass(abc.ABCMeta, type('NewBase', (Loggable, object), {}))):
def __init__(self, interactive=False, path='plot.html', title=""):
......
......@@ -22,10 +22,7 @@ import itertools
import numpy as np
from ..space import Space
from ...config import dependency_injector as gdi
pyHealpix = gdi.get('pyHealpix')
import pyHealpix
class GLSpace(Space):
"""
......
......@@ -19,7 +19,6 @@
from builtins import str
from parameterized import parameterized
from nifty import RGSpace, LMSpace, HPSpace, GLSpace, PowerSpace
from nifty.config import dependency_injector as gdi
def custom_name_func(testcase_func, param_num, param):
......@@ -36,9 +35,7 @@ def expand(*args, **kwargs):
def generate_spaces():
spaces = [RGSpace(4), PowerSpace(RGSpace((4, 4), harmonic=True)),
LMSpace(5), HPSpace(4)]
if 'pyHealpix' in gdi:
spaces.append(GLSpace(4))
LMSpace(5), HPSpace(4), GLSpace(4)]
return spaces
......
......@@ -21,7 +21,6 @@ import unittest
import numpy as np
from numpy.testing import assert_equal,\
assert_allclose
from nifty.config import dependency_injector as gdi
from nifty import Field,\
RGSpace,\
LMSpace,\
......@@ -41,17 +40,14 @@ def _get_rtol(tp):
class FFTOperatorTests(unittest.TestCase):
@expand(product(["numpy", "fftw"],
[16, ], [0.1, 1, 3.7],
@expand(product([16, ], [0.1, 1, 3.7],
[np.float64, np.float32, np.complex64, np.complex128],
['real', 'complex']))
def test_fft1D(self, module, dim1, d, itp, base):
if module == "fftw" and "fftw" not in gdi:
raise SkipTest
def test_fft1D(self, dim1, d, itp, base):
tol = _get_rtol(itp)
a = RGSpace(dim1, distances=d)
b = RGSpace(dim1, distances=1./(dim1*d), harmonic=True)
fft = FFTOperator(domain=a, target=b, module=module)
fft = FFTOperator(domain=a, target=b)
fft._forward_transformation.harmonic_base = base
fft._backward_transformation.harmonic_base = base
......@@ -61,20 +57,17 @@ class FFTOperatorTests(unittest.TestCase):
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product(["numpy", "fftw"],
[12, 15], [9, 12], [0.1, 1, 3.7],
@expand(product([12, 15], [9, 12], [0.1, 1, 3.7],
[0.4, 1, 2.7],
[np.float64, np.float32, np.complex64, np.complex128],
['real', 'complex']))
def test_fft2D(self, module, dim1, dim2, d1, d2,
def test_fft2D(self, dim1, dim2, d1, d2,
itp, base):
if module == "fftw" and "fftw" not in gdi:
raise SkipTest
tol = _get_rtol(itp)
a = RGSpace([dim1, dim2], distances=[d1, d2])
b = RGSpace([dim1, dim2],
distances=[1./(dim1*d1), 1./(dim2*d2)], harmonic=True)
fft = FFTOperator(domain=a, target=b, module=module)
fft = FFTOperator(domain=a, target=b)
fft._forward_transformation.harmonic_base = base
fft._backward_transformation.harmonic_base = base
......@@ -83,17 +76,14 @@ class FFTOperatorTests(unittest.TestCase):
out = fft.adjoint_times(fft.times(inp))
assert_allclose(inp.val, out.val, rtol=tol, atol=tol)
@expand(product(["numpy", "fftw"],
[0, 1, 2],
@expand(product([0, 1, 2],
[np.float64, np.float32, np.complex64, np.complex128],
['real', 'complex']))
def test_composed_fft(self, module, index, dtype,
def test_composed_fft(self, index, dtype,
base):
if module == "fftw" and "fftw" not in gdi:
raise SkipTest
tol = _get_rtol(dtype)
a = [a1, a2, a3] = [RGSpace((32,)), RGSpace((4, 4)), RGSpace((5, 6))]
fft = FFTOperator(domain=a[index], module=module,
fft = FFTOperator(domain=a[index],
default_spaces=(index,))
fft._forward_transformation.harmonic_base = base
fft._backward_transformation.harmonic_base = base
......@@ -106,8 +96,6 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([0, 3, 6, 11, 30],
[np.float64, np.float32, np.complex64, np.complex128]))
def test_sht(self, lm, tp):
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
b = GLSpace(nlat=lm+1)
......@@ -120,8 +108,6 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256],
[np.float64, np.float32, np.complex64, np.complex128]))
def test_sht2(self, lm, tp):
if 'pyHealpix' not in gdi:
raise SkipTest
a = LMSpace(lmax=lm)
b = HPSpace(nside=lm//2)
fft = FFTOperator(domain=a, target=b)
......@@ -133,8 +119,6 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256],
[np.float64, np.float32, np.complex64, np.complex128]))
def test_dotsht(self, lm, tp):
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
b = GLSpace(nlat=lm+1)
......@@ -149,8 +133,6 @@ class FFTOperatorTests(unittest.TestCase):
@expand(product([128, 256],
[np.float64, np.float32, np.complex64, np.complex128]))
def test_dotsht2(self, lm, tp):
if 'pyHealpix' not in gdi:
raise SkipTest
tol = _get_rtol(tp)
a = LMSpace(lmax=lm)
b = HPSpace(nside=lm//2)
......
......@@ -24,7 +24,6 @@ from numpy.testing import assert_, assert_equal, assert_raises,\
assert_almost_equal
from nose.plugins.skip import SkipTest
from nifty import GLSpace
from nifty.config import dependency_injector as gdi