Commit 27d9a668 authored by Jait Dixit's avatar Jait Dixit
Browse files

Merge branch 'master' into tests

parents 8b1c3069 fa788a4e
...@@ -13,7 +13,7 @@ before_script: ...@@ -13,7 +13,7 @@ before_script:
- apt-get update - apt-get update
- > - >
apt-get install -y build-essential python python-pip python-dev git apt-get install -y build-essential python python-pip python-dev git
gfortran autoconf gsl-bin libgsl-dev gfortran autoconf gsl-bin libgsl-dev wget
- pip install -r ci/requirements_base.txt - pip install -r ci/requirements_base.txt
- chmod +x ci/*.sh - chmod +x ci/*.sh
...@@ -58,9 +58,10 @@ test_mpi_fftw_hdf5: ...@@ -58,9 +58,10 @@ test_mpi_fftw_hdf5:
libfftw3-quad3 libfftw3-single3 libfftw3-quad3 libfftw3-single3
- > - >
apt-get install -y libhdf5-10 libhdf5-dev libhdf5-openmpi-10 apt-get install -y libhdf5-10 libhdf5-dev libhdf5-openmpi-10
libhdf5-openmpi-dev hdf5-tools python-h5py libhdf5-openmpi-dev hdf5-tools
- pip install astropy healpy mpi4py - pip install astropy healpy mpi4py
- pip install git+https://github.com/mrbell/gfft - pip install git+https://github.com/mrbell/gfft
- ci/install_h5py.sh
- ci/install_libsharp.sh - ci/install_libsharp.sh
- ci/install_pyfftw.sh - ci/install_pyfftw.sh
- python setup.py build_ext --inplace - python setup.py build_ext --inplace
......
#!/bin/bash
wget https://api.github.com/repos/h5py/h5py/tags -O - | grep tarball_url | grep -v rc | head -n 1 | cut -d '"' -f 4 | wget -i - -O h5py.tar.gz
tar xzf h5py.tar.gz
cd h5py-h5py*
export CC=mpicc
export HDF5_DIR=/usr/lib/x86_64-linux-gnu/hdf5/openmpi
python setup.py configure --mpi
python setup.py build
python setup.py install
cd ..
rm -r h5py-h5py*
rm h5py.tar.gz
from nifty import * from nifty import *
import plotly.offline as pl #import plotly.offline as pl
import plotly.graph_objs as go #import plotly.graph_objs as go
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
......
# -*- coding: utf-8 -*-
import abc
import numpy as np
from keepers import Loggable,\
Versionable
class DomainObject(Versionable, Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, dtype):
self._dtype = np.dtype(dtype)
self._ignore_for_hash = []
def __hash__(self):
# Extract the identifying parts from the vars(self) dict.
result_hash = 0
for key in sorted(vars(self).keys()):
item = vars(self)[key]
if key in self._ignore_for_hash or key == '_ignore_for_hash':
continue
result_hash ^= item.__hash__() ^ int(hash(key)/117)
return result_hash
def __eq__(self, x):
if isinstance(x, type(self)):
return hash(self) == hash(x)
else:
return False
def __ne__(self, x):
return not self.__eq__(x)
@property
def dtype(self):
return self._dtype
@abc.abstractproperty
def shape(self):
raise NotImplementedError(
"There is no generic shape for DomainObject.")
@abc.abstractproperty
def dim(self):
raise NotImplementedError(
"There is no generic dim for DomainObject.")
def pre_cast(self, x, axes=None):
return x
def post_cast(self, x, axes=None):
return x
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group.attrs['dtype'] = self.dtype.name
return None
@classmethod
def _from_hdf5(cls, hdf5_group, repository):
result = cls(dtype=np.dtype(hdf5_group.attrs['dtype']))
return result
...@@ -13,6 +13,14 @@ class Energy(object): ...@@ -13,6 +13,14 @@ class Energy(object):
def at(self, position): def at(self, position):
return self.__class__(position) return self.__class__(position)
@property
def position(self):
return self._position
@position.setter
def position(self, position):
self._position = position
@property @property
def value(self): def value(self):
raise NotImplementedError raise NotImplementedError
......
...@@ -9,9 +9,8 @@ from d2o import distributed_data_object,\ ...@@ -9,9 +9,8 @@ from d2o import distributed_data_object,\
from nifty.config import nifty_configuration as gc from nifty.config import nifty_configuration as gc
from nifty.field_types import FieldType from nifty.domain_object import DomainObject
from nifty.spaces.space import Space
from nifty.spaces.power_space import PowerSpace from nifty.spaces.power_space import PowerSpace
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
...@@ -21,25 +20,15 @@ from nifty.random import Random ...@@ -21,25 +20,15 @@ from nifty.random import Random
class Field(Loggable, Versionable, object): class Field(Loggable, Versionable, object):
# ---Initialization methods--- # ---Initialization methods---
def __init__(self, domain=None, val=None, dtype=None, field_type=None, def __init__(self, domain=None, val=None, dtype=None,
distribution_strategy=None, copy=False): distribution_strategy=None, copy=False):
self.domain = self._parse_domain(domain=domain, val=val) self.domain = self._parse_domain(domain=domain, val=val)
self.domain_axes = self._get_axes_tuple(self.domain) self.domain_axes = self._get_axes_tuple(self.domain)
self.field_type = self._parse_field_type(field_type, val=val)
try:
start = len(reduce(lambda x, y: x+y, self.domain_axes))
except TypeError:
start = 0
self.field_type_axes = self._get_axes_tuple(self.field_type,
start=start)
self.dtype = self._infer_dtype(dtype=dtype, self.dtype = self._infer_dtype(dtype=dtype,
val=val, val=val,
domain=self.domain, domain=self.domain)
field_type=self.field_type)
self.distribution_strategy = self._parse_distribution_strategy( self.distribution_strategy = self._parse_distribution_strategy(
distribution_strategy=distribution_strategy, distribution_strategy=distribution_strategy,
...@@ -53,34 +42,18 @@ class Field(Loggable, Versionable, object): ...@@ -53,34 +42,18 @@ class Field(Loggable, Versionable, object):
domain = val.domain domain = val.domain
else: else:
domain = () domain = ()
elif isinstance(domain, Space): elif isinstance(domain, DomainObject):
domain = (domain,) domain = (domain,)
elif not isinstance(domain, tuple): elif not isinstance(domain, tuple):
domain = tuple(domain) domain = tuple(domain)
for d in domain: for d in domain:
if not isinstance(d, Space): if not isinstance(d, DomainObject):
raise TypeError( raise TypeError(
"Given domain contains something that is not a " "Given domain contains something that is not a "
"nifty.space.") "DomainObject instance.")
return domain return domain
def _parse_field_type(self, field_type, val=None):
if field_type is None:
if isinstance(val, Field):
field_type = val.field_type
else:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
def _get_axes_tuple(self, things_with_shape, start=0): def _get_axes_tuple(self, things_with_shape, start=0):
i = start i = start
axes_list = [] axes_list = []
...@@ -92,7 +65,7 @@ class Field(Loggable, Versionable, object): ...@@ -92,7 +65,7 @@ class Field(Loggable, Versionable, object):
axes_list += [tuple(l)] axes_list += [tuple(l)]
return tuple(axes_list) return tuple(axes_list)
def _infer_dtype(self, dtype, val, domain, field_type): def _infer_dtype(self, dtype, val, domain):
if dtype is None: if dtype is None:
if isinstance(val, Field) or \ if isinstance(val, Field) or \
isinstance(val, distributed_data_object): isinstance(val, distributed_data_object):
...@@ -102,8 +75,6 @@ class Field(Loggable, Versionable, object): ...@@ -102,8 +75,6 @@ class Field(Loggable, Versionable, object):
dtype_tuple = (np.dtype(dtype),) dtype_tuple = (np.dtype(dtype),)
if domain is not None: if domain is not None:
dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain) dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain)
if field_type is not None:
dtype_tuple += tuple(np.dtype(ft.dtype) for ft in field_type)
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple) dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
...@@ -127,10 +98,10 @@ class Field(Loggable, Versionable, object): ...@@ -127,10 +98,10 @@ class Field(Loggable, Versionable, object):
# ---Factory methods--- # ---Factory methods---
@classmethod @classmethod
def from_random(cls, random_type, domain=None, dtype=None, field_type=None, def from_random(cls, random_type, domain=None, dtype=None,
distribution_strategy=None, **kwargs): distribution_strategy=None, **kwargs):
# create a initially empty field # create a initially empty field
f = cls(domain=domain, dtype=dtype, field_type=field_type, f = cls(domain=domain, dtype=dtype,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
# now use the processed input in terms of f in order to parse the # now use the processed input in terms of f in order to parse the
...@@ -259,7 +230,9 @@ class Field(Loggable, Versionable, object): ...@@ -259,7 +230,9 @@ class Field(Loggable, Versionable, object):
result_domain = list(self.domain) result_domain = list(self.domain)
result_domain[space_index] = power_domain result_domain[space_index] = power_domain
result_field = self.copy_empty(domain=result_domain) result_field = self.copy_empty(
domain=result_domain,
distribution_strategy=power_spectrum.distribution_strategy)
result_field.set_val(new_val=power_spectrum, copy=False) result_field.set_val(new_val=power_spectrum, copy=False)
return result_field return result_field
...@@ -361,7 +334,6 @@ class Field(Loggable, Versionable, object): ...@@ -361,7 +334,6 @@ class Field(Loggable, Versionable, object):
std=std, std=std,
domain=result_domain, domain=result_domain,
dtype=harmonic_domain.dtype, dtype=harmonic_domain.dtype,
field_type=self.field_type,
distribution_strategy=self.distribution_strategy) distribution_strategy=self.distribution_strategy)
for x in result_list] for x in result_list]
...@@ -449,9 +421,7 @@ class Field(Loggable, Versionable, object): ...@@ -449,9 +421,7 @@ class Field(Loggable, Versionable, object):
@property @property
def shape(self): def shape(self):
shape_tuple = () shape_tuple = tuple(sp.shape for sp in self.domain)
shape_tuple += tuple(sp.shape for sp in self.domain)
shape_tuple += tuple(ft.shape for ft in self.field_type)
try: try:
global_shape = reduce(lambda x, y: x + y, shape_tuple) global_shape = reduce(lambda x, y: x + y, shape_tuple)
except TypeError: except TypeError:
...@@ -461,9 +431,7 @@ class Field(Loggable, Versionable, object): ...@@ -461,9 +431,7 @@ class Field(Loggable, Versionable, object):
@property @property
def dim(self): def dim(self):
dim_tuple = () dim_tuple = tuple(sp.dim for sp in self.domain)
dim_tuple += tuple(sp.dim for sp in self.domain)
dim_tuple += tuple(ft.dim for ft in self.field_type)
try: try:
return reduce(lambda x, y: x * y, dim_tuple) return reduce(lambda x, y: x * y, dim_tuple)
except TypeError: except TypeError:
...@@ -498,20 +466,12 @@ class Field(Loggable, Versionable, object): ...@@ -498,20 +466,12 @@ class Field(Loggable, Versionable, object):
casted_x = sp.pre_cast(casted_x, casted_x = sp.pre_cast(casted_x,
axes=self.domain_axes[ind]) axes=self.domain_axes[ind])
for ind, ft in enumerate(self.field_type):
casted_x = ft.pre_cast(casted_x,
axes=self.field_type_axes[ind])
casted_x = self._actual_cast(casted_x, dtype=dtype) casted_x = self._actual_cast(casted_x, dtype=dtype)
for ind, sp in enumerate(self.domain): for ind, sp in enumerate(self.domain):
casted_x = sp.post_cast(casted_x, casted_x = sp.post_cast(casted_x,
axes=self.domain_axes[ind]) axes=self.domain_axes[ind])
for ind, ft in enumerate(self.field_type):
casted_x = ft.post_cast(casted_x,
axes=self.field_type_axes[ind])
return casted_x return casted_x
def _actual_cast(self, x, dtype=None): def _actual_cast(self, x, dtype=None):
...@@ -528,19 +488,16 @@ class Field(Loggable, Versionable, object): ...@@ -528,19 +488,16 @@ class Field(Loggable, Versionable, object):
return_x.set_full_data(x, copy=False) return_x.set_full_data(x, copy=False)
return return_x return return_x
def copy(self, domain=None, dtype=None, field_type=None, def copy(self, domain=None, dtype=None, distribution_strategy=None):
distribution_strategy=None):
copied_val = self.get_val(copy=True) copied_val = self.get_val(copy=True)
new_field = self.copy_empty( new_field = self.copy_empty(
domain=domain, domain=domain,
dtype=dtype, dtype=dtype,
field_type=field_type,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
new_field.set_val(new_val=copied_val, copy=False) new_field.set_val(new_val=copied_val, copy=False)
return new_field return new_field
def copy_empty(self, domain=None, dtype=None, field_type=None, def copy_empty(self, domain=None, dtype=None, distribution_strategy=None):
distribution_strategy=None):
if domain is None: if domain is None:
domain = self.domain domain = self.domain
else: else:
...@@ -551,11 +508,6 @@ class Field(Loggable, Versionable, object): ...@@ -551,11 +508,6 @@ class Field(Loggable, Versionable, object):
else: else:
dtype = np.dtype(dtype) dtype = np.dtype(dtype)
if field_type is None:
field_type = self.field_type
else:
field_type = self._parse_field_type(field_type)
if distribution_strategy is None: if distribution_strategy is None:
distribution_strategy = self.distribution_strategy distribution_strategy = self.distribution_strategy
...@@ -565,10 +517,6 @@ class Field(Loggable, Versionable, object): ...@@ -565,10 +517,6 @@ class Field(Loggable, Versionable, object):
if self.domain[i] is not domain[i]: if self.domain[i] is not domain[i]:
fast_copyable = False fast_copyable = False
break break
for i in xrange(len(self.field_type)):
if self.field_type[i] is not field_type[i]:
fast_copyable = False
break
except IndexError: except IndexError:
fast_copyable = False fast_copyable = False
...@@ -578,7 +526,6 @@ class Field(Loggable, Versionable, object): ...@@ -578,7 +526,6 @@ class Field(Loggable, Versionable, object):
else: else:
new_field = Field(domain=domain, new_field = Field(domain=domain,
dtype=dtype, dtype=dtype,
field_type=field_type,
distribution_strategy=distribution_strategy) distribution_strategy=distribution_strategy)
return new_field return new_field
...@@ -624,8 +571,6 @@ class Field(Loggable, Versionable, object): ...@@ -624,8 +571,6 @@ class Field(Loggable, Versionable, object):
assert len(x.domain) == len(self.domain) assert len(x.domain) == len(self.domain)
for index in xrange(len(self.domain)): for index in xrange(len(self.domain)):
assert x.domain[index] == self.domain[index] assert x.domain[index] == self.domain[index]
for index in xrange(len(self.field_type)):
assert x.field_type[index] == self.field_type[index]
except AssertionError: except AssertionError:
raise ValueError( raise ValueError(
"domains are incompatible.") "domains are incompatible.")
...@@ -705,22 +650,15 @@ class Field(Loggable, Versionable, object): ...@@ -705,22 +650,15 @@ class Field(Loggable, Versionable, object):
return_field.set_val(new_val, copy=False) return_field.set_val(new_val, copy=False)
return return_field return return_field
def _contraction_helper(self, op, spaces, types): def _contraction_helper(self, op, spaces):
# build a list of all axes # build a list of all axes
if spaces is None: if spaces is None:
spaces = xrange(len(self.domain)) spaces = xrange(len(self.domain))
else: else:
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain)) spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
if types is None: axes_list = tuple(self.domain_axes[sp_index] for sp_index in spaces)
types = xrange(len(self.field_type))
else:
types = utilities.cast_axis_to_tuple(types, len(self.field_type))
axes_list = ()
axes_list += tuple(self.domain_axes[sp_index] for sp_index in spaces)
axes_list += tuple(self.field_type_axes[ft_index] for
ft_index in types)
try: try:
axes_list = reduce(lambda x, y: x+y, axes_list) axes_list = reduce(lambda x, y: x+y, axes_list)
except TypeError: except TypeError:
...@@ -737,47 +675,44 @@ class Field(Loggable, Versionable, object): ...@@ -737,47 +675,44 @@ class Field(Loggable, Versionable, object):
return_domain = tuple(self.domain[i] return_domain = tuple(self.domain[i]
for i in xrange(len(self.domain)) for i in xrange(len(self.domain))
if i not in spaces) if i not in spaces)
return_field_type = tuple(self.field_type[i]
for i in xrange(len(self.field_type))
if i not in types)
return_field = Field(domain=return_domain, return_field = Field(domain=return_domain,
val=data, val=data,
field_type=return_field_type,
copy=False) copy=False)
return return_field return return_field
def sum(self, spaces=None, types=None): def sum(self, spaces=None):
return self._contraction_helper('sum', spaces, types) return self._contraction_helper('sum', spaces)
def prod(self, spaces=None, types=None): def prod(self, spaces=None):
return self._contraction_helper('prod', spaces, types) return self._contraction_helper('prod', spaces)
def all(self, spaces=None, types=None): def all(self, spaces=None):
return self._contraction_helper('all', spaces, types) return self._contraction_helper('all', spaces)
def any(self, spaces=None, types=None): def any(self, spaces=None):
return self._contraction_helper('any', spaces, types) return self._contraction_helper('any', spaces)
def min(self, spaces=None, types=None): def min(self, spaces=None):
return self._contraction_helper('min', spaces, types) return self._contraction_helper('min', spaces)
def nanmin(self, spaces=None, types=None): def nanmin(self, spaces=None):
return self._contraction_helper('nanmin', spaces, types) return self._contraction_helper('nanmin', spaces)
def max(self, spaces=None, types=None): def max(self, spaces=None):
return self._contraction_helper('max', spaces, types) return self._contraction_helper('max', spaces)
def nanmax(self, spaces=None, types=None): def nanmax(self, spaces=None):
return self._contraction_helper('nanmax', spaces, types) return self._contraction_helper('nanmax', spaces)
def mean(self, spaces=None, types=None): def mean(self, spaces=None):
return self._contraction_helper('mean', spaces, types) return self._contraction_helper('mean', spaces)
def var(self, spaces=None, types=None): def var(self, spaces=None):
return self._contraction_helper('var', spaces, types) return self._contraction_helper('var', spaces)
def std(self, spaces=None, types=None): def std(self, spaces=None):
return self._contraction_helper('std', spaces, types) return self._contraction_helper('std', spaces)
# ---General binary methods--- # ---General binary methods---
...@@ -788,9 +723,6 @@ class Field(Loggable, Versionable, object): ...@@ -788,9 +723,6 @@ class Field(Loggable, Versionable, object):
assert len(other.domain) == len(self.domain) assert len(other.domain) == len(self.domain)
for index in xrange(len(self.domain)): for index in xrange(len(self.domain)):
assert other.domain[index] == self.domain[index] assert other.domain[index] == self.domain[index]
assert len(other.field_type) == len(self.field_type)
for index in xrange(len(self.field_type)):
assert other.field_type[index] == self.field_type[index]
except AssertionError: except AssertionError:
raise ValueError( raise ValueError(
"domains are incompatible.") "domains are incompatible.")
...@@ -893,19 +825,14 @@ class Field(Loggable, Versionable, object): ...@@ -893,19 +825,14 @@ class Field(Loggable, Versionable, object):
def _to_hdf5(self, hdf5_group): def _to_hdf5(self, hdf5_group):
hdf5_group.attrs['dtype'] = self.dtype.name hdf5_group.attrs['dtype'] = self.dtype.name
hdf5_group.attrs