Commit 01c352f0 authored by theos's avatar theos
Browse files

Removed paradict and moved the functionality into the spaces as properties.

Reordered methods in space classes.
Made Space.weight to an abc.abstractmethod.
parent 5adfb1c8
...@@ -39,7 +39,6 @@ from d2o import distributed_data_object, d2o_librarian ...@@ -39,7 +39,6 @@ from d2o import distributed_data_object, d2o_librarian
from nifty_cmaps import ncmap from nifty_cmaps import ncmap
from field import Field from field import Field
from paradict import Paradict
# this line exists for compatibility reasons # this line exists for compatibility reasons
# TODO: Remove this once the transition to field types is done. # TODO: Remove this once the transition to field types is done.
......
...@@ -23,6 +23,8 @@ from __future__ import division ...@@ -23,6 +23,8 @@ from __future__ import division
from linear_operator import LinearOperator from linear_operator import LinearOperator
from diagonal_operator import DiagonalOperator
from endomorphic_operator import EndomorphicOperator from endomorphic_operator import EndomorphicOperator
from fft_operator import * from fft_operator import *
......
...@@ -6,7 +6,6 @@ from nifty.operators.linear_operator import LinearOperator ...@@ -6,7 +6,6 @@ from nifty.operators.linear_operator import LinearOperator
class EndomorphicOperator(LinearOperator): class EndomorphicOperator(LinearOperator):
__metaclass__ = abc.ABCMeta
# ---Overwritten properties and methods--- # ---Overwritten properties and methods---
......
...@@ -44,7 +44,7 @@ class GLLMTransformation(Transformation): ...@@ -44,7 +44,7 @@ class GLLMTransformation(Transformation):
if not isinstance(domain, GLSpace): if not isinstance(domain, GLSpace):
raise TypeError('ERROR: domain needs to be a GLSpace') raise TypeError('ERROR: domain needs to be a GLSpace')
nlat = domain.paradict['nlat'] nlat = domain.nlat
lmax = nlat - 1 lmax = nlat - 1
mmax = nlat - 1 mmax = nlat - 1
if domain.dtype == np.dtype('float32'): if domain.dtype == np.dtype('float32'):
...@@ -63,10 +63,10 @@ class GLLMTransformation(Transformation): ...@@ -63,10 +63,10 @@ class GLLMTransformation(Transformation):
if not isinstance(codomain, LMSpace): if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.') raise TypeError('ERROR: codomain must be a LMSpace.')
nlat = domain.paradict['nlat'] nlat = domain.nlat
nlon = domain.paradict['nlon'] nlon = domain.nlon
lmax = codomain.paradict['lmax'] lmax = codomain.lmax
mmax = codomain.paradict['mmax'] mmax = codomain.mmax
if (nlon != 2 * nlat - 1) or (lmax != nlat - 1) or (lmax != mmax): if (nlon != 2 * nlat - 1) or (lmax != nlat - 1) or (lmax != mmax):
return False return False
...@@ -90,10 +90,10 @@ class GLLMTransformation(Transformation): ...@@ -90,10 +90,10 @@ class GLLMTransformation(Transformation):
val = self.domain.weight(val, power=-0.5, axes=axes) val = self.domain.weight(val, power=-0.5, axes=axes)
# shorthands for transform parameters # shorthands for transform parameters
nlat = self.domain.paradict['nlat'] nlat = self.domain.nlat
nlon = self.domain.paradict['nlon'] nlon = self.domain.nlon
lmax = self.codomain.paradict['lmax'] lmax = self.codomain.lmax
mmax = self.codomain.paradict['mmax'] mmax = self.codomain.mmax
if isinstance(val, distributed_data_object): if isinstance(val, distributed_data_object):
temp_val = val.get_full_data() temp_val = val.get_full_data()
......
...@@ -44,7 +44,7 @@ class HPLMTransformation(Transformation): ...@@ -44,7 +44,7 @@ class HPLMTransformation(Transformation):
if not isinstance(domain, HPSpace): if not isinstance(domain, HPSpace):
raise TypeError('ERROR: domain needs to be a HPSpace') raise TypeError('ERROR: domain needs to be a HPSpace')
lmax = 3 * domain.paradict['nside'] - 1 lmax = 3 * domain.nside - 1
mmax = lmax mmax = lmax
return LMSpace(lmax=lmax, mmax=mmax, dtype=np.dtype('complex128')) return LMSpace(lmax=lmax, mmax=mmax, dtype=np.dtype('complex128'))
...@@ -59,9 +59,9 @@ class HPLMTransformation(Transformation): ...@@ -59,9 +59,9 @@ class HPLMTransformation(Transformation):
if not isinstance(codomain, LMSpace): if not isinstance(codomain, LMSpace):
raise TypeError('ERROR: codomain must be a LMSpace.') raise TypeError('ERROR: codomain must be a LMSpace.')
nside = domain.paradict['nside'] nside = domain.nside
lmax = codomain.paradict['lmax'] lmax = codomain.lmax
mmax = codomain.paradict['mmax'] mmax = codomain.mmax
if (3 * nside - 1 != lmax) or (lmax != mmax): if (3 * nside - 1 != lmax) or (lmax != mmax):
return False return False
...@@ -88,8 +88,8 @@ class HPLMTransformation(Transformation): ...@@ -88,8 +88,8 @@ class HPLMTransformation(Transformation):
val = self.domain.weight(val, power=-0.5, axes=axes) val = self.domain.weight(val, power=-0.5, axes=axes)
# shorthands for transform parameters # shorthands for transform parameters
lmax = self.codomain.paradict['lmax'] lmax = self.codomain.lmax
mmax = self.codomain.paradict['mmax'] mmax = self.codomain.mmax
if isinstance(val, distributed_data_object): if isinstance(val, distributed_data_object):
temp_val = val.get_full_data() temp_val = val.get_full_data()
......
...@@ -59,8 +59,8 @@ class LMGLTransformation(Transformation): ...@@ -59,8 +59,8 @@ class LMGLTransformation(Transformation):
else: else:
raise ValueError('ERROR: unsupported domain dtype') raise ValueError('ERROR: unsupported domain dtype')
nlat = domain.paradict['lmax'] + 1 nlat = domain.lmax + 1
nlon = domain.paradict['lmax'] * 2 + 1 nlon = domain.lmax * 2 + 1
return GLSpace(nlat=nlat, nlon=nlon, dtype=new_dtype) return GLSpace(nlat=nlat, nlon=nlon, dtype=new_dtype)
@staticmethod @staticmethod
...@@ -74,10 +74,10 @@ class LMGLTransformation(Transformation): ...@@ -74,10 +74,10 @@ class LMGLTransformation(Transformation):
if not isinstance(codomain, GLSpace): if not isinstance(codomain, GLSpace):
raise TypeError('ERROR: codomain must be a GLSpace.') raise TypeError('ERROR: codomain must be a GLSpace.')
nlat = codomain.paradict['nlat'] nlat = codomain.nlat
nlon = codomain.paradict['nlon'] nlon = codomain.nlon
lmax = domain.paradict['lmax'] lmax = domain.lmax
mmax = domain.paradict['mmax'] mmax = domain.mmax
if (lmax != mmax) or (nlat != lmax + 1) or (nlon != 2 * lmax + 1): if (lmax != mmax) or (nlat != lmax + 1) or (nlon != 2 * lmax + 1):
return False return False
...@@ -112,10 +112,10 @@ class LMGLTransformation(Transformation): ...@@ -112,10 +112,10 @@ class LMGLTransformation(Transformation):
return_val = np.empty_like(temp_val) return_val = np.empty_like(temp_val)
inp = temp_val[slice_list] inp = temp_val[slice_list]
nlat = self.codomain.paradict['nlat'] nlat = self.codomain.nlat
nlon = self.codomain.paradict['nlon'] nlon = self.codomain.nlon
lmax = self.domain.paradict['lmax'] lmax = self.domain.lmax
mmax = self.paradict['mmax'] mmax = self.mmax
if self.domain.dtype == np.dtype('complex64'): if self.domain.dtype == np.dtype('complex64'):
inp = gl.alm2map_f(inp, nlat=nlat, nlon=nlon, inp = gl.alm2map_f(inp, nlat=nlat, nlon=nlon,
......
...@@ -48,7 +48,7 @@ class LMHPTransformation(Transformation): ...@@ -48,7 +48,7 @@ class LMHPTransformation(Transformation):
if not isinstance(domain, LMSpace): if not isinstance(domain, LMSpace):
raise TypeError('ERROR: domain needs to be a LMSpace') raise TypeError('ERROR: domain needs to be a LMSpace')
nside = (domain.paradict['lmax'] + 1) // 3 nside = (domain.lmax + 1) // 3
return HPSpace(nside=nside) return HPSpace(nside=nside)
@staticmethod @staticmethod
...@@ -61,9 +61,9 @@ class LMHPTransformation(Transformation): ...@@ -61,9 +61,9 @@ class LMHPTransformation(Transformation):
if not isinstance(codomain, HPSpace): if not isinstance(codomain, HPSpace):
raise TypeError('ERROR: codomain must be a HPSpace.') raise TypeError('ERROR: codomain must be a HPSpace.')
nside = codomain.paradict['nside'] nside = codomain.nside
lmax = domain.paradict['lmax'] lmax = domain.lmax
mmax = domain.paradict['mmax'] mmax = domain.mmax
if (lmax != mmax) or (3 * nside - 1 != lmax): if (lmax != mmax) or (3 * nside - 1 != lmax):
return False return False
...@@ -98,9 +98,9 @@ class LMHPTransformation(Transformation): ...@@ -98,9 +98,9 @@ class LMHPTransformation(Transformation):
return_val = np.empty_like(temp_val) return_val = np.empty_like(temp_val)
inp = temp_val[slice_list] inp = temp_val[slice_list]
nside = self.codomain.paradict['nside'] nside = self.codomain.nside
lmax = self.domain.paradict['lmax'] lmax = self.domain.lmax
mmax = self.domain.paradict['mmax'] mmax = self.domain.mmax
inp = inp.astype(np.complex128, copy=False) inp = inp.astype(np.complex128, copy=False)
inp = hp.alm2map(inp, nside, lmax=lmax, mmax=mmax, inp = hp.alm2map(inp, nside, lmax=lmax, mmax=mmax,
......
...@@ -218,7 +218,7 @@ class FFTW(Transform): ...@@ -218,7 +218,7 @@ class FFTW(Transform):
def _atomic_mpi_transform(self, val, info, axes): def _atomic_mpi_transform(self, val, info, axes):
# Apply codomain centering mask # Apply codomain centering mask
if reduce(lambda x, y: x + y, self.codomain.paradict['zerocenter']): if reduce(lambda x, y: x + y, self.codomain.zerocenter):
temp_val = np.copy(val) temp_val = np.copy(val)
val = self._apply_mask(temp_val, info.cmask_codomain, axes) val = self._apply_mask(temp_val, info.cmask_codomain, axes)
...@@ -235,7 +235,7 @@ class FFTW(Transform): ...@@ -235,7 +235,7 @@ class FFTW(Transform):
return None return None
# Apply domain centering mask # Apply domain centering mask
if reduce(lambda x, y: x + y, self.domain.paradict['zerocenter']): if reduce(lambda x, y: x + y, self.domain.zerocenter):
result = self._apply_mask(result, info.cmask_domain, axes) result = self._apply_mask(result, info.cmask_domain, axes)
# Correct the sign if needed # Correct the sign if needed
...@@ -263,7 +263,7 @@ class FFTW(Transform): ...@@ -263,7 +263,7 @@ class FFTW(Transform):
**kwargs) **kwargs)
# Apply codomain centering mask # Apply codomain centering mask
if reduce(lambda x, y: x + y, self.codomain.paradict['zerocenter']): if reduce(lambda x, y: x + y, self.codomain.zerocenter):
temp_val = np.copy(local_val) temp_val = np.copy(local_val)
local_val = self._apply_mask(temp_val, local_val = self._apply_mask(temp_val,
current_info.cmask_codomain, axes) current_info.cmask_codomain, axes)
...@@ -275,7 +275,7 @@ class FFTW(Transform): ...@@ -275,7 +275,7 @@ class FFTW(Transform):
) )
# Apply domain centering mask # Apply domain centering mask
if reduce(lambda x, y: x + y, self.domain.paradict['zerocenter']): if reduce(lambda x, y: x + y, self.domain.zerocenter):
local_result = self._apply_mask(local_result, local_result = self._apply_mask(local_result,
current_info.cmask_domain, axes) current_info.cmask_domain, axes)
...@@ -446,19 +446,19 @@ class FFTWTransformInfo(object): ...@@ -446,19 +446,19 @@ class FFTWTransformInfo(object):
raise ImportError("The module pyfftw is needed but not available.") raise ImportError("The module pyfftw is needed but not available.")
self.cmask_domain = fftw_context.get_centering_mask( self.cmask_domain = fftw_context.get_centering_mask(
domain.paradict['zerocenter'], domain.zerocenter,
local_shape, local_shape,
local_offset_Q) local_offset_Q)
self.cmask_codomain = fftw_context.get_centering_mask( self.cmask_codomain = fftw_context.get_centering_mask(
codomain.paradict['zerocenter'], codomain.zerocenter,
local_shape, local_shape,
local_offset_Q) local_offset_Q)
# If both domain and codomain are zero-centered the result, # If both domain and codomain are zero-centered the result,
# will get a global minus. Store the sign to correct it. # will get a global minus. Store the sign to correct it.
self.sign = (-1) ** np.sum(np.array(domain.paradict['zerocenter']) * self.sign = (-1) ** np.sum(np.array(domain.zerocenter) *
np.array(codomain.paradict['zerocenter']) * np.array(codomain.zerocenter) *
(np.array(domain.shape) // 2 % 2)) (np.array(domain.shape) // 2 % 2))
@property @property
...@@ -611,13 +611,13 @@ class GFFT(Transform): ...@@ -611,13 +611,13 @@ class GFFT(Transform):
out_ax=[], out_ax=[],
ftmachine='fft' if self.codomain.harmonic else 'ifft', ftmachine='fft' if self.codomain.harmonic else 'ifft',
in_zero_center=map( in_zero_center=map(
bool, self.domain.paradict['zerocenter'] bool, self.domain.zerocenter
), ),
out_zero_center=map( out_zero_center=map(
bool, self.codomain.paradict['zerocenter'] bool, self.codomain.zerocenter
), ),
enforce_hermitian_symmetry=bool( enforce_hermitian_symmetry=bool(
self.codomain.paradict['complexity'] self.codomain.complexity
), ),
W=-1, W=-1,
alpha=-1, alpha=-1,
......
...@@ -62,20 +62,20 @@ class RGRGTransformation(Transformation): ...@@ -62,20 +62,20 @@ class RGRGTransformation(Transformation):
# parse the cozerocenter input # parse the cozerocenter input
if zerocenter is None: if zerocenter is None:
zerocenter = domain.paradict['zerocenter'] zerocenter = domain.zerocenter
# if the input is something scalar, cast it to a boolean # if the input is something scalar, cast it to a boolean
else: else:
temp = np.empty_like(domain.paradict['zerocenter']) temp = np.empty_like(domain.zerocenter)
temp[:] = zerocenter temp[:] = zerocenter
zerocenter = temp zerocenter = temp
# calculate the initialization parameters # calculate the initialization parameters
distances = 1 / (np.array(domain.paradict['shape']) * distances = 1 / (np.array(domain.shape) *
np.array(domain.paradict['distances'])) np.array(domain.distances))
if dtype is None: if dtype is None:
dtype = np.complex dtype = np.complex
new_space = RGSpace(domain.paradict['shape'], new_space = RGSpace(domain.shape,
zerocenter=zerocenter, zerocenter=zerocenter,
distances=distances, distances=distances,
harmonic=(not domain.harmonic), harmonic=(not domain.harmonic),
...@@ -94,8 +94,8 @@ class RGRGTransformation(Transformation): ...@@ -94,8 +94,8 @@ class RGRGTransformation(Transformation):
if not isinstance(codomain, RGSpace): if not isinstance(codomain, RGSpace):
return False return False
if not np.all(np.array(domain.paradict['shape']) == if not np.all(np.array(domain.shape) ==
np.array(codomain.paradict['shape'])): np.array(codomain.shape)):
return False return False
if domain.harmonic == codomain.harmonic: if domain.harmonic == codomain.harmonic:
...@@ -103,9 +103,9 @@ class RGRGTransformation(Transformation): ...@@ -103,9 +103,9 @@ class RGRGTransformation(Transformation):
# Check if the distances match, i.e. dist' = 1 / (num * dist) # Check if the distances match, i.e. dist' = 1 / (num * dist)
if not np.all( if not np.all(
np.absolute(np.array(domain.paradict['shape']) * np.absolute(np.array(domain.shape) *
np.array(domain.paradict['distances']) * np.array(domain.distances) *
np.array(codomain.paradict['distances']) - 1) < np.array(codomain.distances) - 1) <
10**-7): 10**-7):
return False return False
......
...@@ -16,22 +16,6 @@ class LinearOperator(object): ...@@ -16,22 +16,6 @@ class LinearOperator(object):
self._domain = self._parse_domain(domain) self._domain = self._parse_domain(domain)
self._field_type = self._parse_field_type(field_type) self._field_type = self._parse_field_type(field_type)
@property
def domain(self):
return self._domain
@abc.abstractproperty
def target(self):
raise NotImplementedError
@property
def field_type(self):
return self._field_type
@abc.abstractproperty
def field_type_target(self):
raise NotImplementedError
def _parse_domain(self, domain): def _parse_domain(self, domain):
if domain is None: if domain is None:
domain = () domain = ()
...@@ -61,6 +45,22 @@ class LinearOperator(object): ...@@ -61,6 +45,22 @@ class LinearOperator(object):
"ERROR: Given object is not a nifty.FieldType.")) "ERROR: Given object is not a nifty.FieldType."))
return field_type return field_type
@property
def domain(self):
return self._domain
@abc.abstractproperty
def target(self):
raise NotImplementedError
@property
def field_type(self):
return self._field_type
@abc.abstractproperty
def field_type_target(self):
raise NotImplementedError
@abc.abstractproperty @abc.abstractproperty
def implemented(self): def implemented(self):
raise NotImplementedError raise NotImplementedError
......
# -*- coding: utf-8 -*-
class Paradict(object):
def __init__(self, **kwargs):
if not hasattr(self, 'parameters'):
self.parameters = {}
for key in kwargs:
self[key] = kwargs[key]
def __iter__(self):
return self.parameters.__iter__()
def __eq__(self, other):
return (isinstance(other, self.__class__) and
self.__dict__ == other.__dict__)
def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
return self.parameters.__repr__()
def __setitem__(self, key, arg):
raise NotImplementedError
def __getitem__(self, key):
return self.parameters.__getitem__(key)
def __hash__(self):
result_hash = 0
for (key, item) in self.parameters.items():
try:
temp_hash = hash(item)
except TypeError:
temp_hash = hash(tuple(item))
result_hash ^= temp_hash ^ int(hash(key)/131)
return result_hash
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from gl_space import GLSpace,\ from gl_space import GLSpace
GLSpaceParadict from hp_space import HPSpace
from lm_space import LMSpace
from hp_space import HPSpace,\ from power_space import PowerSpace
HPSpaceParadict from rg_space import RGSpace
from space import Space
from lm_space import LMSpace,\ \ No newline at end of file
LMSpaceParadict
from power_space import PowerSpace,\
PowerSpaceParadict
from rg_space import RGSpace,\
RGSpaceParadict
from space import Space,\
SpaceParadict
...@@ -2,4 +2,3 @@ ...@@ -2,4 +2,3 @@
from gl_space import GLSpace from gl_space import GLSpace
from gl_space_paradict import GLSpaceParadict
\ No newline at end of file
...@@ -8,7 +8,6 @@ from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES ...@@ -8,7 +8,6 @@ from d2o import STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.spaces.space import Space from nifty.spaces.space import Space
from nifty.config import about, nifty_configuration as gc,\ from nifty.config import about, nifty_configuration as gc,\
dependency_injector as gdi dependency_injector as gdi
from gl_space_paradict import GLSpaceParadict
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
gl = gdi.get('libsharp_wrapper_gl') gl = gdi.get('libsharp_wrapper_gl')
...@@ -69,7 +68,9 @@ class GLSpace(Space): ...@@ -69,7 +68,9 @@ class GLSpace(Space):
An array containing the pixel sizes. An array containing the pixel sizes.
""" """
def __init__(self, nlat, nlon=None, dtype=np.dtype('float')): # ---Overwritten properties and met