Commit cfcdd7d8 authored by theos's avatar theos

Modified nifty_field.py to successfully initialize with multiple spaces and a field_type.

Therefore made the Field_type base class more advanced.
Added default field dtype to nifty configuration.
parent 4ff7b806
......@@ -50,7 +50,8 @@ from nifty_paradict import space_paradict,\
point_space_paradict,\
nested_space_paradict
from field_types import Field_array
from field_types import Field_type,\
Field_array
from operators import *
......
# -*- coding: utf-8 -*-
import os
import numpy as np
import keepers
# Setup the dependency injector
......@@ -50,6 +50,21 @@ variable_verbosity = keepers.Variable('verbosity',
lambda z: z == abs(int(z)),
genus='int')
def _dtype_validator(dtype):
try:
np.dtype(dtype)
except(TypeError):
return False
else:
return True
variable_default_field_dtype = keepers.Variable(
'default_field_dtype',
['float'],
_dtype_validator,
genus='str')
variable_default_datamodel = keepers.Variable(
'default_datamodel',
['fftw', 'equal'],
......@@ -64,6 +79,7 @@ nifty_configuration = keepers.get_Configuration(
variable_use_healpy,
variable_use_libsharp,
variable_verbosity,
variable_default_field_dtype,
variable_default_datamodel,
],
path=os.path.expanduser('~') + "/.nifty/nifty_config")
......
# -*- coding: utf-8 -*-
from field_array import Field_array
\ No newline at end of file
from field_array import Field_type,\
Field_array
\ No newline at end of file
# -*- coding: utf-8 -*-
import numpy as np
class Base_field_type(object):
def __init__(self, shape):
self.shape = shape
@property
def shape(self):
return self._shape
@shape.setter
def shape(self, shape):
class Field_type(object):
def __init__(self, shape, dtype):
try:
new_shape = tuple([int(i) for i in shape])
except TypeError:
new_shape = (int(shape), )
self._shape = new_shape
self._dtype = np.dtype(dtype)
@property
def shape(self):
return self._shape
@property
def dtype(self):
return self._dtype
def get_dof(self, split=False):
if issubclass(self.dtype.type, np.complexfloating):
multiplicator = 2
else:
multiplicator = 1
if split:
dof = tuple(multiplicator*np.array(self.shape))
else:
dof = multiplicator*reduce(lambda x, y: x*y, self.shape)
return dof
def process(self, method_name, array, inplace=True, **kwargs):
try:
result_array = self.__getattr__(method_name)(array,
......@@ -29,3 +46,6 @@ class Base_field_type(object):
result_array = array.copy()
return result_array
def complement_cast(self, x, axis=None):
return x
# -*- coding: utf-8 -*-
from base_field_type import Base_field_type
from base_field_type import Field_type
class Field_array(Base_field_type):
class Field_array(Field_type):
pass
......@@ -311,7 +311,7 @@ class space(object):
raise NotImplementedError(about._errors.cstring(
"ERROR: no generic instance method 'dof'."))
def _complement_cast(self, x, axis=None):
def complement_cast(self, x, axis=None):
return x
# TODO: Move enforce power into power_indices class
......
......@@ -9,6 +9,9 @@ from nifty.config import about, \
nifty_configuration as gc, \
dependency_injector as gdi
from nifty.field_types import Field_type,\
Field_array
from nifty.nifty_core import space
import nifty.nifty_utilities as utilities
......@@ -102,10 +105,9 @@ class field(object):
"""
def __init__(self, domain=None, val=None, codomain=None,
def __init__(self, domain, val=None, codomain=None,
dtype=None, field_type=None, copy=False,
datamodel=gc['default_datamodel'], comm=gc['default_comm'],
**kwargs):
datamodel=None, comm=None, **kwargs):
"""
Sets the attributes for a field class instance.
......@@ -137,6 +139,7 @@ class field(object):
comm=comm,
copy=copy,
dtype=dtype,
field_type=field_type,
datamodel=datamodel,
**kwargs)
else:
......@@ -146,11 +149,12 @@ class field(object):
comm=comm,
copy=copy,
dtype=dtype,
field_type=field_type,
datamodel=datamodel,
**kwargs)
def _init_from_field(self, f, domain, codomain, comm, copy, dtype,
datamodel, **kwargs):
field_type, datamodel, **kwargs):
# check domain
if domain is None:
domain = f.domain
......@@ -179,37 +183,45 @@ class field(object):
**kwargs)
def _init_from_array(self, val, domain, codomain, comm, copy, dtype,
datamodel, **kwargs):
if dtype is None:
dtype = self._get_dtype_from_domain(domain)
self.dtype = dtype
self.comm = self._parse_comm(comm)
# if val is a distributed data object, we take it's datamodel,
# since we don't want to redistribute large amounts of data, if not
# necessary
if isinstance(val, distributed_data_object):
if datamodel != val.distribution_strategy:
about.warnings.cprint("WARNING: datamodel set to val's "
"datamodel.")
datamodel = val.distribution_strategy
if datamodel not in DISTRIBUTION_STRATEGIES['global']:
about.warnings.cprint("WARNING: datamodel set to default.")
self.datamodel = \
gc['default_distribution_strategy']
else:
self.datamodel = datamodel
field_type, datamodel, **kwargs):
# check domain
self.domain = self._check_valid_domain(domain=domain)
self._axis_list = self._get_axis_list_from_domain(domain=domain)
self.domain = self._parse_domain(domain=domain)
self._axis_list = self._get_axis_list_from_domain(domain=self.domain)
# check codomain
if codomain is None:
codomain = self.get_codomain(domain=domain)
elif not self._check_codomain(domain=domain, codomain=codomain):
codomain = self.get_codomain(domain=self.domain)
else:
self.codomain = self._parse_codomain(codomain, self.domain)
if field_type is None:
field_type = Field_array(shape=(), dtype=np.float)
elif not isinstance(field_type, Field_type):
raise ValueError(about._errors.cstring(
"ERROR: The given codomain is not compatible to the domain."))
self.codomain = codomain
"ERROR: The given field_type object is not an "
"instance of nifty.Field_type."))
self.field_type = field_type
if dtype is None:
dtype = self._infer_dtype(domain=self.domain,
dtype=dtype,
field_type=self.field_type)
self.dtype = dtype
if comm is not None:
self.comm = self._parse_comm(comm)
elif isinstance(val, distributed_data_object):
self.comm = val.comm
else:
self.comm = gc['default_comm']
if datamodel in DISTRIBUTION_STRATEGIES['all']:
self.datamodel = datamodel
elif isinstance(val, distributed_data_object):
self.datamodel = val.distribution_strategy
else:
self.datamodel = gc['default_datamodel']
if val is None:
if kwargs == {}:
......@@ -220,16 +232,17 @@ class field(object):
**kwargs)
self.set_val(new_val=val, copy=copy)
def _get_dtype_from_domain(self, domain=None):
if domain is None:
domain = self.domain
dtype_tuple = tuple(np.dtype(space.dtype) for space in domain)
def _infer_dtype(self, domain=None, dtype=None, field_type=None):
dtype_tuple = (np.dtype(gc['default_field_dtype']),)
if domain is not None:
dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain)
if field_type is not None:
dtype_tuple += (field_type.dtype,)
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
return dtype
def _get_axis_list_from_domain(self, domain=None):
if domain is None:
domain = self.domain
def _get_axis_list_from_domain(self, domain):
i = 0
axis_list = []
for sp in domain:
......@@ -259,35 +272,35 @@ class field(object):
"default-MPI-module's Intracomm Class."))
return result_comm
def _check_valid_domain(self, domain):
def _parse_domain(self, domain):
if not isinstance(domain, tuple):
raise TypeError(about._errors.cstring(
"ERROR: The given domain is not a list."))
domain = (domain,)
for d in domain:
if not isinstance(d, space):
raise TypeError(about._errors.cstring(
"ERROR: Given domain is not a space."))
elif d.dtype > self.dtype:
raise AttributeError(about._errors.cstring(
"ERROR: The dtype of a space in the domain is larger than "
"the field's dtype."))
"ERROR: Given domain contains something that is not a "
"nifty.space."))
return domain
def _check_codomain(self, domain, codomain):
if codomain is None:
return False
if len(domain) == len(codomain):
return np.all(map((lambda d, c: d._check_codomain(c)), domain,
codomain))
else:
return False
def _parse_codomain(self, codomain, domain):
if not isinstance(codomain, tuple):
codomain = (codomain,)
if len(domain) != len(codomain):
raise ValueError(about._errors.cstring(
"ERROR: domain and codomain do not have the same length."))
for (cd, d) in zip(codomain, domain):
if not isinstance(cd, space):
raise TypeError(about._errors.cstring(
"ERROR: Given codomain contains something that is not a"
"nifty.space."))
if not d.check_codomain(cd):
raise ValueError(about._errors.cstring(
"ERROR: codomain contains a space that is not compatible "
"to its domain-counterpart."))
return codomain
def get_codomain(self, domain):
if len(domain) == 1:
return (domain[0].get_codomain(),)
else:
codomain = tuple(space.get_codomain() for space in domain)
self.codomain = codomain
codomain = tuple(sp.get_codomain() for sp in domain)
return codomain
def get_random_values(self, domain=None, codomain=None, **kwargs):
......@@ -297,10 +310,13 @@ class field(object):
def __len__(self):
return int(self.get_dim()[0])
def copy(self, domain=None, codomain=None, **kwargs):
def copy(self, domain=None, codomain=None, field_type=None, **kwargs):
copied_val = self._unary_operation(self.get_val(), op='copy', **kwargs)
new_field = self.copy_empty(domain=domain, codomain=codomain)
new_field.set_val(new_val=copied_val)
new_field = self.copy_empty(domain=domain,
codomain=codomain,
field_type=field_type)
new_field.set_val(new_val=copied_val,
copy=True)
return new_field
def _fast_copy_empty(self):
......@@ -318,28 +334,41 @@ class field(object):
return new_field
def copy_empty(self, domain=None, codomain=None, dtype=None, comm=None,
datamodel=None, **kwargs):
datamodel=None, field_type=None, **kwargs):
if domain is None:
domain = self.domain
if codomain is None:
codomain = self.codomain
if dtype is None:
dtype = self.dtype
if comm is None:
comm = self.comm
if datamodel is None:
datamodel = self.datamodel
if (domain is self.domain and
codomain is self.codomain and
dtype == self.dtype and
comm == self.comm and
if field_type is None:
field_type = self.field_type
_fast_copyable = True
for i in len(self.domain):
if self.domain[i] is not domain[i]:
_fast_copyable = False
break
if self.codomain[i] is not codomain[i]:
_fast_copyable = False
break
if (_fast_copyable and dtype == self.dtype and comm == self.comm and
datamodel == self.datamodel and
kwargs == {}):
field_type is self.field_type and kwargs == {}):
new_field = self._fast_copy_empty()
else:
new_field = field(domain=domain, codomain=codomain, dtype=dtype,
comm=comm, datamodel=datamodel, **kwargs)
comm=comm, datamodel=datamodel,
field_type=field_type, **kwargs)
return new_field
def set_val(self, new_val=None, copy=False):
......@@ -352,34 +381,28 @@ class field(object):
New field values either as a constant or an arbitrary array.
"""
if new_val is not None:
new_val = self.cast(new_val)
if copy:
new_val = self.unary_operation(new_val, op='copy')
self.val = self.cast(new_val)
self.val = new_val
return self.val
def get_val(self):
return self.val
# TODO: Add functionality for boolean indexing.
def __getitem__(self, key):
return self.val[key]
def __setitem__(self, key, item):
self.val[key] = item
def get_shape(self):
if len(self.domain) > 1:
shape_tuple = tuple(space.get_shape() for space in self.domain)
@property
def shape(self):
shape_tuple = tuple(sp.get_shape() for sp in self.domain)
shape_tuple += (self.field_type.shape, )
global_shape = reduce(lambda x, y: x + y, shape_tuple)
else:
global_shape = self.domain[0].get_shape()
if isinstance(global_shape, tuple):
return global_shape
else:
return ()
def get_dim(self):
"""
......@@ -401,19 +424,25 @@ class field(object):
return np.prod(self.get_shape())
def get_dof(self, split=False):
dim = self.get_dim()
if np.issubdtype(self.dtype, np.complex):
return 2 * dim
dof_tuple = tuple(sp.get_dof(split=split) for sp in self.domain)
dof_tuple += (self.field_type.get_dof(split=split),)
if split:
return reduce(lambda x, y: x + y, dof_tuple)
else:
return dim
return reduce(lambda x, y: x * y, dof_tuple)
def cast(self, x=None, dtype=None):
if dtype is not None:
dtype = np.dtype(dtype)
if dtype is None:
dtype = self.dtype
else:
dtype = np.dtype(dtype)
casted_x = self._cast_to_d2o(x, dtype=dtype)
return self._complement_cast(casted_x)
for ind, sp in enumerate(self.domain):
casted_x = sp.complement_cast(casted_x, axis=self._axis_list[ind])
casted_x = self.field_type.complement_cast(casted_x)
return casted_x
def _cast_to_d2o(self, x, dtype=None, shape=None, **kwargs):
"""
......@@ -508,49 +537,6 @@ class field(object):
# Cast the d2o
return self.cast(x, dtype=dtype)
def _complement_cast(self, x):
for ind, space in enumerate(self.domain):
space._complement_cast(x, axis=self._axis_list[ind])
return x
def set_domain(self, new_domain=None, force=False):
"""
Resets the codomain of the field.
Parameters
----------
new_codomain : space
The new space wherein the transform of the field should live.
(default=None).
"""
# check codomain
if new_domain is None:
new_domain = self.codomain.get_codomain()
elif not force:
assert (self.codomain.check_codomain(new_domain))
self.domain = new_domain
return self.domain
def set_codomain(self, new_codomain=None, force=False):
"""
Resets the codomain of the field.
Parameters
----------
new_codomain : space
The new space wherein the transform of the field should live.
(default=None).
"""
# check codomain
if new_codomain is None:
new_codomain = self.domain.get_codomain()
elif not force:
assert (self.domain.check_codomain(new_codomain))
self.codomain = new_codomain
return self.codomain
def weight(self, new_val=None, power=1, overwrite=False, spaces=None):
"""
Returns the field values, weighted with the volume factors to a
......
......@@ -33,7 +33,8 @@ setup(name="ift_nifty",
description="Numerical Information Field Theory",
url="http://www.mpa-garching.mpg.de/ift/nifty/",
packages=["nifty", "nifty.demos", "nifty.rg", "nifty.lm",
"nifty.operators", "nifty.dummys", "nifty.config"],
"nifty.operators", "nifty.dummys", "nifty.field_types",
"nifty.config"],
package_dir={"nifty": ""},
zip_safe=False,
dependency_links=[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment