Commit 01c352f0 authored by theos's avatar theos
Browse files

Removed paradict and moved the functionality into the spaces as properties.

Reordered methods in space classes.
Made Space.weight to an abc.abstractmethod.
parent 5adfb1c8
......@@ -39,7 +39,6 @@ from d2o import distributed_data_object, d2o_librarian
from nifty_cmaps import ncmap
from field import Field
from paradict import Paradict
# this line exists for compatibility reasons
# TODO: Remove this once the transition to field types is done.
......
......@@ -23,6 +23,8 @@ from __future__ import division
from linear_operator import LinearOperator
from diagonal_operator import DiagonalOperator
from endomorphic_operator import EndomorphicOperator
from fft_operator import *
......
......@@ -6,7 +6,6 @@ from nifty.operators.linear_operator import LinearOperator
class EndomorphicOperator(LinearOperator):
__metaclass__ = abc.ABCMeta
# ---Overwritten properties and methods---
......
......@@ -44,7 +44,7 @@ class GLLMTransformation(Transformation):
if not isinstance(domain, GLSpace):
raise TypeError('ERROR: domain needs to be a GLSpace')
nlat = domain.paradict['nlat']
nlat = domain.nlat
lmax = nlat - 1
mmax = nlat - 1
if domain.dtype == np.dtype('float32'):
......@@ -63,10 +63,10 @@ class GLLMTransformation(Transformation):
if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.')
nlat = domain.paradict['nlat']
nlon = domain.paradict['nlon']
lmax = codomain.paradict['lmax']
mmax = codomain.paradict['mmax']
nlat = domain.nlat
nlon = domain.nlon
lmax = codomain.lmax
mmax = codomain.mmax
if (nlon != 2 * nlat - 1) or (lmax != nlat - 1) or (lmax != mmax):
return False
......@@ -90,10 +90,10 @@ class GLLMTransformation(Transformation):
val = self.domain.weight(val, power=-0.5, axes=axes)
# shorthands for transform parameters
nlat = self.domain.paradict['nlat']
nlon = self.domain.paradict['nlon']
lmax = self.codomain.paradict['lmax']
mmax = self.codomain.paradict['mmax']
nlat = self.domain.nlat
nlon = self.domain.nlon
lmax = self.codomain.lmax
mmax = self.codomain.mmax
if isinstance(val, distributed_data_object):
temp_val = val.get_full_data()
......
......@@ -44,7 +44,7 @@ class HPLMTransformation(Transformation):
if not isinstance(domain, HPSpace):
raise TypeError('ERROR: domain needs to be a HPSpace')
lmax = 3 * domain.paradict['nside'] - 1
lmax = 3 * domain.nside - 1
mmax = lmax
return LMSpace(lmax=lmax, mmax=mmax, dtype=np.dtype('complex128'))
......@@ -59,9 +59,9 @@ class HPLMTransformation(Transformation):
if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.')
nside = domain.paradict['nside']
lmax = codomain.paradict['lmax']
mmax = codomain.paradict['mmax']
nside = domain.nside
lmax = codomain.lmax
mmax = codomain.mmax
if (3 * nside - 1 != lmax) or (lmax != mmax):
return False
......@@ -88,8 +88,8 @@ class HPLMTransformation(Transformation):
val = self.domain.weight(val, power=-0.5, axes=axes)
# shorthands for transform parameters
lmax = self.codomain.paradict['lmax']
mmax = self.codomain.paradict['mmax']
lmax = self.codomain.lmax
mmax = self.codomain.mmax
if isinstance(val, distributed_data_object):
temp_val = val.get_full_data()
......
......@@ -59,8 +59,8 @@ class LMGLTransformation(Transformation):
else:
raise ValueError('ERROR: unsupported domain dtype')
nlat = domain.paradict['lmax'] + 1
nlon = domain.paradict['lmax'] * 2 + 1
nlat = domain.lmax + 1
nlon = domain.lmax * 2 + 1
return GLSpace(nlat=nlat, nlon=nlon, dtype=new_dtype)
@staticmethod
......@@ -74,10 +74,10 @@ class LMGLTransformation(Transformation):
if not isinstance(codomain, GLSpace):
raise TypeError('ERROR: codomain must be a GLSpace.')
nlat = codomain.paradict['nlat']
nlon = codomain.paradict['nlon']
lmax = domain.paradict['lmax']
mmax = domain.paradict['mmax']
nlat = codomain.nlat
nlon = codomain.nlon
lmax = domain.lmax
mmax = domain.mmax
if (lmax != mmax) or (nlat != lmax + 1) or (nlon != 2 * lmax + 1):
return False
......@@ -112,10 +112,10 @@ class LMGLTransformation(Transformation):
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
nlat = self.codomain.paradict['nlat']
nlon = self.codomain.paradict['nlon']
lmax = self.domain.paradict['lmax']
mmax = self.paradict['mmax']
nlat = self.codomain.nlat
nlon = self.codomain.nlon
lmax = self.domain.lmax
mmax = self.mmax
if self.domain.dtype == np.dtype('complex64'):
inp = gl.alm2map_f(inp, nlat=nlat, nlon=nlon,
......
......@@ -48,7 +48,7 @@ class LMHPTransformation(Transformation):
if not isinstance(domain, LMSpace):
raise TypeError('ERROR: domain needs to be a LMSpace')
nside = (domain.paradict['lmax'] + 1) // 3
nside = (domain.lmax + 1) // 3
return HPSpace(nside=nside)
@staticmethod
......@@ -61,9 +61,9 @@ class LMHPTransformation(Transformation):
if not isinstance(codomain, HPSpace):
raise TypeError('ERROR: codomain must be a HPSpace.')
nside = codomain.paradict['nside']
lmax = domain.paradict['lmax']
mmax = domain.paradict['mmax']
nside = codomain.nside
lmax = domain.lmax
mmax = domain.mmax
if (lmax != mmax) or (3 * nside - 1 != lmax):
return False
......@@ -98,9 +98,9 @@ class LMHPTransformation(Transformation):
return_val = np.empty_like(temp_val)
inp = temp_val[slice_list]
nside = self.codomain.paradict['nside']
lmax = self.domain.paradict['lmax']
mmax = self.domain.paradict['mmax']
nside = self.codomain.nside
lmax = self.domain.lmax
mmax = self.domain.mmax
inp = inp.astype(np.complex128, copy=False)
inp = hp.alm2map(inp, nside, lmax=lmax, mmax=mmax,
......
......@@ -218,7 +218,7 @@ class FFTW(Transform):
def _atomic_mpi_transform(self, val, info, axes):
# Apply codomain centering mask
if reduce(lambda x, y: x + y, self.codomain.paradict['zerocenter']):
if reduce(lambda x, y: x + y, self.codomain.zerocenter):
temp_val = np.copy(val)
val = self._apply_mask(temp_val, info.cmask_codomain, axes)
......@@ -235,7 +235,7 @@ class FFTW(Transform):
return None
# Apply domain centering mask
if reduce(lambda x, y: x + y, self.domain.paradict['zerocenter']):
if reduce(lambda x, y: x + y, self.domain.zerocenter):
result = self._apply_mask(result, info.cmask_domain, axes)
# Correct the sign if needed
......@@ -263,7 +263,7 @@ class FFTW(Transform):
**kwargs)
# Apply codomain centering mask
if reduce(lambda x, y: x + y, self.codomain.paradict['zerocenter']):
if reduce(lambda x, y: x + y, self.codomain.zerocenter):
temp_val = np.copy(local_val)
local_val = self._apply_mask(temp_val,
current_info.cmask_codomain, axes)
......@@ -275,7 +275,7 @@ class FFTW(Transform):
)
# Apply domain centering mask
if reduce(lambda x, y: x + y, self.domain.paradict['zerocenter']):
if reduce(lambda x, y: x + y, self.domain.zerocenter):
local_result = self._apply_mask(local_result,
current_info.cmask_domain, axes)
......@@ -446,19 +446,19 @@ class FFTWTransformInfo(object):
raise ImportError("The module pyfftw is needed but not available.")
self.cmask_domain = fftw_context.get_centering_mask(
domain.paradict['zerocenter'],
domain.zerocenter,
local_shape,
local_offset_Q)
self.cmask_codomain = fftw_context.get_centering_mask(
codomain.paradict['zerocenter'],
codomain.zerocenter,
local_shape,
local_offset_Q)
# If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it.
self.sign = (-1) ** np.sum(np.array(domain.paradict['zerocenter']) *
np.array(codomain.paradict['zerocenter']) *
self.sign = (-1) ** np.sum(np.array(domain.zerocenter) *
np.array(codomain.zerocenter) *
(np.array(domain.shape) // 2 % 2))
@property
......@@ -611,13 +611,13 @@ class GFFT(Transform):
out_ax=[],
ftmachine='fft' if self.codomain.harmonic else 'ifft',
in_zero_center=map(
bool, self.domain.paradict['zerocenter']
bool, self.domain.zerocenter
),
out_zero_center=map(
bool, self.codomain.paradict['zerocenter']
bool, self.codomain.zerocenter
),
enforce_hermitian_symmetry=bool(
self.codomain.paradict['complexity']
self.codomain.complexity
),
W=-1,
alpha=-1,
......
......@@ -62,20 +62,20 @@ class RGRGTransformation(Transformation):
# parse the cozerocenter input
if zerocenter is None:
zerocenter = domain.paradict['zerocenter']
zerocenter = domain.zerocenter
# if the input is something scalar, cast it to a boolean
else:
temp = np.empty_like(domain.paradict['zerocenter'])
temp = np.empty_like(domain.zerocenter)
temp[:] = zerocenter
zerocenter = temp
# calculate the initialization parameters
distances = 1 / (np.array(domain.paradict['shape']) *
np.array(domain.paradict['distances']))
distances = 1 / (np.array(domain.shape) *
np.array(domain.distances))
if dtype is None:
dtype = np.complex
new_space = RGSpace(domain.paradict['shape'],
new_space = RGSpace(domain.shape,
zerocenter=zerocenter,
distances=distances,
harmonic=(not domain.harmonic),
......@@ -94,8 +94,8 @@ class RGRGTransformation(Transformation):
if not isinstance(codomain, RGSpace):
return False
if not np.all(np.array(domain.paradict['shape']) ==
np.array(codomain.paradict['shape'])):
if not np.all(np.array(domain.shape) ==
np.array(codomain.shape)):
return False
if domain.harmonic == codomain.harmonic:
......@@ -103,9 +103,9 @@ class RGRGTransformation(Transformation):
# Check if the distances match, i.e. dist' = 1 / (num * dist)
if not np.all(
np.absolute(np.array(domain.paradict['shape']) *
np.array(domain.paradict['distances']) *
np.array(codomain.paradict['distances']) - 1) <
np.absolute(np.array(domain.shape) *
np.array(domain.distances) *
np.array(codomain.distances) - 1) <
10**-7):
return False
......
......@@ -16,22 +16,6 @@ class LinearOperator(object):
self._domain = self._parse_domain(domain)
self._field_type = self._parse_field_type(field_type)
@property
def domain(self):
return self._domain
@abc.abstractproperty
def target(self):
raise NotImplementedError
@property
def field_type(self):
return self._field_type
@abc.abstractproperty
def field_type_target(self):
raise NotImplementedError
def _parse_domain(self, domain):
if domain is None:
domain = ()
......@@ -61,6 +45,22 @@ class LinearOperator(object):
"ERROR: Given object is not a nifty.FieldType."))
return field_type
@property
def domain(self):
return self._domain
@abc.abstractproperty
def target(self):
raise NotImplementedError
@property
def field_type(self):
return self._field_type
@abc.abstractproperty
def field_type_target(self):
raise NotImplementedError
@abc.abstractproperty
def implemented(self):
raise NotImplementedError
......
# -*- coding: utf-8 -*-
class Paradict(object):
def __init__(self, **kwargs):
if not hasattr(self, 'parameters'):
self.parameters = {}
for key in kwargs:
self[key] = kwargs[key]
def __iter__(self):
return self.parameters.__iter__()
def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.__dict__ == other.__dict__)
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return self.parameters.__repr__()
def __setitem__(self, key, arg):
raise NotImplementedError
def __getitem__(self, key):
return self.parameters.__getitem__(key)
def __hash__(self):
result_hash = 0
for (key, item) in self.parameters.items():
try:
temp_hash = hash(item)
except TypeError:
temp_hash = hash(tuple(item))
result_hash ^= temp_hash ^ int(hash(key)/131)
return result_hash
# -*- coding: utf-8 -*-
from gl_space import GLSpace,\
GLSpaceParadict
from hp_space import HPSpace,\
HPSpaceParadict
from lm_space import LMSpace,\
LMSpaceParadict
from power_space import PowerSpace,\
PowerSpaceParadict
from rg_space import RGSpace,\
RGSpaceParadict
from space import Space,\
SpaceParadict
from gl_space import GLSpace
from hp_space import HPSpace
from lm_space import LMSpace
from power_space import PowerSpace
from rg_space import RGSpace
from space import Space
\ No newline at end of file
......@@ -2,4 +2,3 @@
from gl_space import GLSpace
from gl_space_paradict import GLSpaceParadict
\ No newline at end of file
......@@ -8,7 +8,6 @@ from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.spaces.space import Space
from nifty.config import about, nifty_configuration as gc,\
dependency_injector as gdi
from gl_space_paradict import GLSpaceParadict
import nifty.nifty_utilities as utilities
gl = gdi.get('libsharp_wrapper_gl')
......@@ -69,7 +68,9 @@ class GLSpace(Space):
An array containing the pixel sizes.
"""
def __init__(self, nlat, nlon=None, dtype=np.dtype('float')):
# ---Overwritten properties and methods---
def __init__(self, nlat=2, nlon=3, dtype=np.dtype('float')):
"""
Sets the attributes for a gl_space class instance.
......@@ -97,34 +98,41 @@ class GLSpace(Space):
# check imports
if not gc['use_libsharp']:
raise ImportError(about._errors.cstring(
"ERROR: libsharp_wrapper_gl not loaded."))
"ERROR: libsharp_wrapper_gl not available or not loaded."))
super(GLSpace, self).__init__(dtype)
# setup paradict
self.paradict = GLSpaceParadict(nlat=nlat, nlon=nlon)
self._nlat = self._parse_nlat(nlat)
self._nlon = self._parse_nlon(nlon)
# setup dtype
self.dtype = np.dtype(dtype)
# ---Mandatory properties and methods---
# GLSpace is not harmonic
self._harmonic = False
@property
def harmonic(self):
return False
@property
def shape(self):
return (np.int((self.paradict['nlat'] * self.paradict['nlon'])),)
return (np.int((self.nlat * self.nlon)),)
@property
def dim(self):
return np.int((self.paradict['nlat'] * self.paradict['nlon']))
return np.int((self.nlat * self.nlon))
@property
def total_volume(self):
return 4 * np.pi
def copy(self):
return self.__class__(nlat=self.nlat,
nlon=self.nlon,
dtype=self.dtype)
def weight(self, x, power=1, axes=None, inplace=False):
axes = utilities.cast_axis_to_tuple(axes, length=1)
nlon = self.paradict['nlon']
nlat = self.paradict['nlat']
nlon = self.nlon
nlat = self.nlat
weight = np.array(list(itertools.chain.from_iterable(
itertools.repeat(x ** power, nlon)
......@@ -145,3 +153,42 @@ class GLSpace(Space):
result_x = x * weight
return result_x
# ---Added properties and methods---
@property
def nlat(self):
return self._nlat
@property
def nlon(self):
return self._nlon
def _parse_nlat(self, nlat):
nlat = int(nlat)
if nlat < 2:
raise ValueError(about._errors.cstring(
"ERROR: nlat must be a positive number."))
elif nlat % 2 != 0:
raise ValueError(about._errors.cstring(
"ERROR: nlat must be a multiple of 2."))
return nlat
def _parse_nlon(self, nlon):
if nlon is None:
nlon = 2 * self.nlat - 1
else:
nlon = int(nlon)
if nlon != 2 * self.nlat - 1:
about.warnings.cprint(
"WARNING: nlon was set to an unrecommended value: "
"nlon <> 2*nlat-1.")
return nlon
# -*- coding: utf-8 -*-
from nifty.config import about
from nifty.spaces.space import SpaceParadict
class GLSpaceParadict(SpaceParadict):
def __init__(self, nlat, nlon):
SpaceParadict.__init__(self, nlat=nlat)
if nlon is None:
nlon = -1
self['nlon'] = nlon
def __setitem__(self, key, arg):
if key not in ['nlat', 'nlon']:
raise ValueError(about._errors.cstring(
"ERROR: Unsupported GLSpace parameter: " + key))
if key == 'nlat':
temp = int(arg)
if(temp < 1):
raise ValueError(about._errors.cstring(
"ERROR: nlat: nonpositive number."))
if (temp % 2 != 0):
raise ValueError(about._errors.cstring(
"ERROR: invalid parameter (nlat <> 2n)."))
try:
if temp < self['mmax']:
about.warnings.cprint(
"WARNING: mmax parameter set to lmax.")
self['mmax'] = temp
if (temp != self['mmax']):
about.warnings.cprint(
"WARNING: unrecommended parameter set (mmax <> lmax).")
except:
pass
elif key == 'nlon':
temp = int(arg)
if (temp < 1):
about.warnings.cprint(
"WARNING: nlon parameter set to default.")
temp = 2 * self['nlat'] - 1
if(temp != 2 * self['nlat'] - 1):
about.warnings.cprint(
"WARNING: unrecommended parameter set (nlon <> 2*nlat-1).")
self.parameters.__setitem__(key, temp)
......@@ -2,4 +2,3 @@
from hp_space import HPSpace
from hp_space_paradict import HPSpaceParadict
......@@ -35,11 +35,10 @@ from __future__ import division