Commit d380119a authored by theos's avatar theos
Browse files

RGPowerSpace can now be initialized and possesses basic space properties and power indices.

parent 7ed80b4d
......@@ -108,7 +108,7 @@ class Field(object):
def __init__(self, domain=None, val=None, codomain=None,
dtype=None, field_type=None, copy=False,
datamodel=None, comm=None, **kwargs):
datamodel=None, **kwargs):
"""
Sets the attributes for a field class instance.
......@@ -137,7 +137,6 @@ class Field(object):
self._init_from_field(f=val,
domain=domain,
codomain=codomain,
comm=comm,
copy=copy,
dtype=dtype,
field_type=field_type,
......@@ -147,14 +146,13 @@ class Field(object):
self._init_from_array(val=val,
domain=domain,
codomain=codomain,
comm=comm,
copy=copy,
dtype=dtype,
field_type=field_type,
datamodel=datamodel,
**kwargs)
def _init_from_field(self, f, domain, codomain, comm, copy, dtype,
def _init_from_field(self, f, domain, codomain, copy, dtype,
field_type, datamodel, **kwargs):
# check domain
if domain is None:
......@@ -177,13 +175,12 @@ class Field(object):
self._init_from_array(domain=domain,
val=f.val,
codomain=codomain,
comm=comm,
copy=copy,
dtype=dtype,
datamodel=datamodel,
**kwargs)
def _init_from_array(self, val, domain, codomain, comm, copy, dtype,
def _init_from_array(self, val, domain, codomain, copy, dtype,
field_type, datamodel, **kwargs):
# check domain
self.domain = self._parse_domain(domain=domain)
......@@ -204,12 +201,7 @@ class Field(object):
field_type=self.field_type)
self.dtype = dtype
if comm is not None:
self.comm = self._parse_comm(comm)
elif isinstance(val, distributed_data_object):
self.comm = val.comm
else:
self.comm = gc['default_comm']
self._comm = getattr(gdi[gc['mpi_module']], gc['default_comm'])
if datamodel in DISTRIBUTION_STRATEGIES['all']:
self.datamodel = datamodel
......@@ -248,25 +240,6 @@ class Field(object):
axes_list += [tuple(l)]
return tuple(axes_list)
def _parse_comm(self, comm):
# check if comm is a string -> the name of comm is given
# -> Extract it from the mpi_module
if isinstance(comm, str):
if gc.validQ('default_comm', comm):
result_comm = getattr(gdi[gc['mpi_module']], comm)
else:
raise ValueError(about._errors.cstring(
"ERROR: The given communicator-name is not supported."))
# check if the given comm object is an instance of default Intracomm
else:
if isinstance(comm, gdi[gc['mpi_module']].Intracomm):
result_comm = comm
else:
raise ValueError(about._errors.cstring(
"ERROR: The given comm object is not an instance of the " +
"default-MPI-module's Intracomm Class."))
return result_comm
def _parse_domain(self, domain):
if domain is None:
domain = ()
......@@ -378,7 +351,7 @@ class Field(object):
self._unary_operation(self.val, op='copy_empty')
return new_field
def copy_empty(self, domain=None, codomain=None, dtype=None, comm=None,
def copy_empty(self, domain=None, codomain=None, dtype=None,
datamodel=None, field_type=None, **kwargs):
if domain is None:
domain = self.domain
......@@ -389,9 +362,6 @@ class Field(object):
if dtype is None:
dtype = self.dtype
if comm is None:
comm = self.comm
if datamodel is None:
datamodel = self.datamodel
......@@ -412,14 +382,13 @@ class Field(object):
_fast_copyable = False
break
if (_fast_copyable and dtype == self.dtype and comm == self.comm and
datamodel == self.datamodel and
kwargs == {}):
if (_fast_copyable and dtype == self.dtype and
datamodel == self.datamodel and kwargs == {}):
new_field = self._fast_copy_empty()
else:
new_field = Field(domain=domain, codomain=codomain, dtype=dtype,
comm=comm, datamodel=datamodel,
field_type=field_type, **kwargs)
datamodel=datamodel, field_type=field_type,
**kwargs)
return new_field
def set_val(self, new_val=None, copy=False):
......@@ -554,6 +523,9 @@ class Field(object):
# Case 1: x is a distributed_data_object
if isinstance(x, distributed_data_object):
if x.comm is not self._comm:
raise ValueError(about._errors.cstring(
"ERROR: comms do not match."))
to_copy = False
# Check the shape
......@@ -609,7 +581,8 @@ class Field(object):
x = distributed_data_object(x,
global_shape=self.shape,
dtype=dtype,
distribution_strategy=self.datamodel)
distribution_strategy=self.datamodel,
comm=self._comm)
# Cast the d2o
return self.cast(x, dtype=dtype)
......
......@@ -116,9 +116,9 @@ class GLSpace(Space):
self.discrete = False
self.harmonic = False
self.distances = tuple(gl.vol(self.paradict['nlat'],
nlon=self.paradict['nlon']
).astype(np.float))
self.distances = (tuple(gl.vol(self.paradict['nlat'],
nlon=self.paradict['nlon']
).astype(np.float)),)
@property
def para(self):
......@@ -140,6 +140,10 @@ class GLSpace(Space):
def shape(self):
return (np.int((self.paradict['nlat'] * self.paradict['nlon'])),)
@property
def vol(self):
return np.sum(self.paradict['nlon'] * np.array(self.distances[0]))
@property
def meta_volume(self):
"""
......
......@@ -12,7 +12,8 @@ from nifty.config import about
class space_paradict(object):
def __init__(self, **kwargs):
self.parameters = {}
if not hasattr(self, 'parameters'):
self.parameters = {}
for key in kwargs:
self[key] = kwargs[key]
......@@ -56,7 +57,7 @@ class rg_space_paradict(space_paradict):
def __setitem__(self, key, arg):
if key not in ['shape', 'complexity', 'zerocenter']:
raise ValueError(about._errors.cstring(
"ERROR: Unsupported rg_space parameter"))
"ERROR: Unsupported RGSpace parameter:" + key))
if key == 'shape':
temp = np.array(arg, dtype=np.int).flatten()
......@@ -90,7 +91,7 @@ class lm_space_paradict(space_paradict):
def __setitem__(self, key, arg):
if key not in ['lmax', 'mmax']:
raise ValueError(about._errors.cstring(
"ERROR: Unsupported rg_space parameter"))
"ERROR: Unsupported LMSpace parameter: " + key))
if key == 'lmax':
temp = np.int(arg)
......@@ -135,7 +136,7 @@ class gl_space_paradict(space_paradict):
def __setitem__(self, key, arg):
if key not in ['nlat', 'nlon']:
raise ValueError(about._errors.cstring(
"ERROR: Unsupported rg_space parameter"))
"ERROR: Unsupported GLSpace parameter: " + key))
if key == 'nlat':
temp = int(arg)
......@@ -187,7 +188,7 @@ class hp_space_paradict(space_paradict):
class power_space_paradict(space_paradict):
def __init__(self, distribution_strategy, log, nbin, binbounds):
super(power_space_paradict, self).__init___(
space_paradict.__init__(self,
distribution_strategy=distribution_strategy,
log=log,
nbin=nbin,
......@@ -196,7 +197,7 @@ class power_space_paradict(space_paradict):
def __setitem__(self, key, arg):
if key not in ['distribution_strategy', 'log', 'nbin', 'binbounds']:
raise ValueError(about._errors.cstring(
"ERROR: Unsupported PowerSpace parameter"))
"ERROR: Unsupported PowerSpace parameter: " + key))
if key == 'log':
try:
......@@ -223,32 +224,38 @@ class power_space_paradict(space_paradict):
class rg_power_space_paradict(power_space_paradict, rg_space_paradict):
def __init__(self, shape, dgrid, zerocentered, log, nbin, binbounds):
rg_space_paradict.__init___(shape=shape,
dgrid=dgrid,
zerocentered=zerocentered,
log=log,
nbin=nbin,
binbounds=binbounds)
def __init__(self, shape, dgrid, zerocenter, distribution_strategy,
log, nbin, binbounds):
rg_space_paradict.__init__(self,
shape=shape,
complexity=0,
zerocenter=zerocenter)
power_space_paradict.__init__(
self,
distribution_strategy=distribution_strategy,
log=log,
nbin=nbin,
binbounds=binbounds)
self['dgrid'] = dgrid
def __setitem__(self, key, arg):
if key not in ['shape', 'zerocentered', 'distribution_strategy',
'log', 'nbin', 'binbounds']:
if key not in ['shape', 'complexity', 'zerocenter',
'distribution_strategy', 'log', 'nbin', 'binbounds',
'dgrid']:
raise ValueError(about._errors.cstring(
"ERROR: Unsupported RGPowerSpace parameter"))
"ERROR: Unsupported RGPowerSpace parameter: " + key))
if key in ['distribution_strategy', 'log', 'nbin', 'binbounds']:
power_space_paradict.__setitem__(key, arg)
power_space_paradict.__setitem__(self, key, arg)
elif key == 'dgrid':
temp = np.array(arg, dtype=np.float).flatten()
if np.size(temp) != self.ndim:
temp = temp * np.ones(self.ndim, dtype=np.float)
temp = tuple(temp)
if len(temp) != self.ndim:
raise ValueError(about._errors.cstring(
"ERROR: Number of dimensions does not match the init "
"value."))
self.parameters.__setitem__(key, temp)
else:
rg_space_paradict.__setitem__(key, arg)
rg_space_paradict.__setitem__(self, key, arg)
......
......@@ -3,3 +3,7 @@
from power_space import PowerSpace
from rg_power_space import RGPowerSpace
from lm_power_space import LMPowerSpace
from power_index_factory import PowerIndexFactory,\
RGPowerIndexFactory,\
LMPowerIndexFactory
\ No newline at end of file
# -*- coding: utf-8 -*-
from power_indices import PowerIndices,\
RGPowerIndices,\
LMPowerIndices
class _PowerIndexFactory(object):
def __init__(self):
self.power_indices_storage = {}
def _get_power_index_class(self):
return PowerIndices
def hash_arguments(self, **kwargs):
return frozenset(kwargs.items())
def get_power_indices(self, log, nbin, binbounds, **kwargs):
current_hash = self.hash_arguments(**kwargs)
if current_hash not in self.power_indices_storage:
power_class = self._get_power_index_class()
self.power_indices_storage[current_hash] = power_class(
log=log,
nbin=nbin,
binbounds=binbounds,
**kwargs)
power_indices = self.power_indices_storage[current_hash]
power_index = power_indices.get_index_dict(log=log,
nbin=nbin,
binbounds=binbounds)
return power_index
class _RGPowerIndexFactory(_PowerIndexFactory):
def _get_power_index_class(self):
return RGPowerIndices
class _LMPowerIndexFactory(_PowerIndexFactory):
def _get_power_index_class(self):
return LMPowerIndices
PowerIndexFactory = _PowerIndexFactory()
RGPowerIndexFactory = _RGPowerIndexFactory()
LMPowerIndexFactory = _LMPowerIndexFactory()
This diff is collapsed.
......@@ -18,6 +18,10 @@ class PowerSpace(Space):
# Here it would be time to initialize the power indices
raise NotImplementedError
self.distances = None
self.harmonic = True
def calculate_power_spectrum(self):
raise NotImplementedError
......
......@@ -4,22 +4,32 @@ import numpy as np
from nifty.power import PowerSpace
from nifty.nifty_paradict import rg_power_space_paradict
# from nifty.power.power_index_factory import RGPowerIndexFactory
from power_index_factory import RGPowerIndexFactory
class RGPowerSpace(PowerSpace):
def __init__(self, shape, dgrid, distribution_strategy, zerocentered=False,
dtype=np.dtype('float'), log=False, nbin=None,
binbounds=None):
def __init__(self, shape, dgrid, distribution_strategy,
dtype=np.dtype('float'), zerocenter=False,
log=False, nbin=None, binbounds=None):
self.dtype = np.dtype(dtype)
self.paradict = rg_power_space_paradict(
shape=shape,
dgrid=dgrid,
zerocentered=zerocentered,
zerocenter=zerocenter,
distribution_strategy=distribution_strategy,
log=log,
nbin=nbin,
binbounds=binbounds)
# self.power_indices = RGPowerIndexFactory.get_power_indices(
# **self.paradict.parameters)
temp_dict = self.paradict.parameters.copy()
del temp_dict['complexity']
self.power_indices = RGPowerIndexFactory.get_power_indices(**temp_dict)
self.distances = (tuple(self.power_indices['rho']),)
self.harmonic = True
self.discrete = False
def calculate_power_spectrum(self, x, axes=None):
fieldabs = abs(x)**2
# need a bincount with axes function here.
......@@ -163,25 +163,10 @@ class RGSpace(Space):
else:
self.dtype = np.dtype('complex128')
# set volume/distances
naxes = len(self.paradict['shape'])
if distances is None:
distances = 1 / np.array(self.paradict['shape'], dtype=np.float)
elif np.isscalar(distances):
distances = np.ones(naxes, dtype=np.float) * distances
else:
distances = np.array(distances, dtype=np.float)
if np.size(distances) == 1:
distances = distances * np.ones(naxes, dtype=np.float)
if np.size(distances) != naxes:
raise ValueError(about._errors.cstring(
"ERROR: size mismatch ( " + str(np.size(distances)) +
" <> " + str(naxes) + " )."))
if np.any(distances <= 0):
raise ValueError(about._errors.cstring(
"ERROR: nonpositive distance(s)."))
distances = 1 / np.array(self.shape, dtype=np.float)
self.distances = tuple(distances)
self.distances = distances
self.harmonic = bool(harmonic)
self.discrete = False
......
......@@ -220,9 +220,27 @@ class Space(object):
"WARNING: incompatible dtype: " + str(dtype)))
self.dtype = dtype
self.discrete = True
# self.harmonic = False
self.distances = (np.float(1),)
self.discrete = None
self.harmonic = None
self._distances = None
@property
def distances(self):
return self._distances
@distances.setter
def distances(self, distances):
naxes = len(self.shape)
if np.isscalar(distances):
distances = tuple(np.ones(naxes, dtype=np.float) * distances)
elif not isinstance(distances, tuple):
distances = tuple(distances)
if len(distances) != naxes:
raise ValueError(about._errors.cstring(
"ERROR: size mismatch ( " + str(np.size(distances)) +
" <> " + str(naxes) + " )."))
self._distances = distances
@property
def para(self):
......@@ -321,8 +339,9 @@ class Space(object):
return dof
@property
def vol(self, split=False):
return reduce(lambda x, y: x * y, self.distances)
def vol(self):
collapsed_distances = [np.sum(x) for x in self.distances]
return reduce(lambda x, y: x * y, collapsed_distances)
@property
def vol_split(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment