Commit 27d9a668 authored by Jait Dixit's avatar Jait Dixit
Browse files

Merge branch 'master' into tests

parents 8b1c3069 fa788a4e
......@@ -13,7 +13,7 @@ before_script:
- apt-get update
- >
apt-get install -y build-essential python python-pip python-dev git
gfortran autoconf gsl-bin libgsl-dev
gfortran autoconf gsl-bin libgsl-dev wget
- pip install -r ci/requirements_base.txt
- chmod +x ci/*.sh
......@@ -58,9 +58,10 @@ test_mpi_fftw_hdf5:
libfftw3-quad3 libfftw3-single3
- >
apt-get install -y libhdf5-10 libhdf5-dev libhdf5-openmpi-10
libhdf5-openmpi-dev hdf5-tools python-h5py
libhdf5-openmpi-dev hdf5-tools
- pip install astropy healpy mpi4py
- pip install git+https://github.com/mrbell/gfft
- ci/install_h5py.sh
- ci/install_libsharp.sh
- ci/install_pyfftw.sh
- python setup.py build_ext --inplace
......
#!/bin/bash
wget https://api.github.com/repos/h5py/h5py/tags -O - | grep tarball_url | grep -v rc | head -n 1 | cut -d '"' -f 4 | wget -i - -O h5py.tar.gz
tar xzf h5py.tar.gz
cd h5py-h5py*
export CC=mpicc
export HDF5_DIR=/usr/lib/x86_64-linux-gnu/hdf5/openmpi
python setup.py configure --mpi
python setup.py build
python setup.py install
cd ..
rm -r h5py-h5py*
rm h5py.tar.gz
from nifty import *
import plotly.offline as pl
import plotly.graph_objs as go
#import plotly.offline as pl
#import plotly.graph_objs as go
from mpi4py import MPI
comm = MPI.COMM_WORLD
......
# -*- coding: utf-8 -*-
import abc
import numpy as np
from keepers import Loggable,\
Versionable
class DomainObject(Versionable, Loggable, object):
__metaclass__ = abc.ABCMeta
def __init__(self, dtype):
self._dtype = np.dtype(dtype)
self._ignore_for_hash = []
def __hash__(self):
# Extract the identifying parts from the vars(self) dict.
result_hash = 0
for key in sorted(vars(self).keys()):
item = vars(self)[key]
if key in self._ignore_for_hash or key == '_ignore_for_hash':
continue
result_hash ^= item.__hash__() ^ int(hash(key)/117)
return result_hash
def __eq__(self, x):
if isinstance(x, type(self)):
return hash(self) == hash(x)
else:
return False
def __ne__(self, x):
return not self.__eq__(x)
@property
def dtype(self):
return self._dtype
@abc.abstractproperty
def shape(self):
raise NotImplementedError(
"There is no generic shape for DomainObject.")
@abc.abstractproperty
def dim(self):
raise NotImplementedError(
"There is no generic dim for DomainObject.")
def pre_cast(self, x, axes=None):
return x
def post_cast(self, x, axes=None):
return x
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group.attrs['dtype'] = self.dtype.name
return None
@classmethod
def _from_hdf5(cls, hdf5_group, repository):
result = cls(dtype=np.dtype(hdf5_group.attrs['dtype']))
return result
......@@ -13,6 +13,14 @@ class Energy(object):
def at(self, position):
return self.__class__(position)
@property
def position(self):
return self._position
@position.setter
def position(self, position):
self._position = position
@property
def value(self):
raise NotImplementedError
......
......@@ -9,9 +9,8 @@ from d2o import distributed_data_object,\
from nifty.config import nifty_configuration as gc
from nifty.field_types import FieldType
from nifty.domain_object import DomainObject
from nifty.spaces.space import Space
from nifty.spaces.power_space import PowerSpace
import nifty.nifty_utilities as utilities
......@@ -21,25 +20,15 @@ from nifty.random import Random
class Field(Loggable, Versionable, object):
# ---Initialization methods---
def __init__(self, domain=None, val=None, dtype=None, field_type=None,
def __init__(self, domain=None, val=None, dtype=None,
distribution_strategy=None, copy=False):
self.domain = self._parse_domain(domain=domain, val=val)
self.domain_axes = self._get_axes_tuple(self.domain)
self.field_type = self._parse_field_type(field_type, val=val)
try:
start = len(reduce(lambda x, y: x+y, self.domain_axes))
except TypeError:
start = 0
self.field_type_axes = self._get_axes_tuple(self.field_type,
start=start)
self.dtype = self._infer_dtype(dtype=dtype,
val=val,
domain=self.domain,
field_type=self.field_type)
domain=self.domain)
self.distribution_strategy = self._parse_distribution_strategy(
distribution_strategy=distribution_strategy,
......@@ -53,34 +42,18 @@ class Field(Loggable, Versionable, object):
domain = val.domain
else:
domain = ()
elif isinstance(domain, Space):
elif isinstance(domain, DomainObject):
domain = (domain,)
elif not isinstance(domain, tuple):
domain = tuple(domain)
for d in domain:
if not isinstance(d, Space):
if not isinstance(d, DomainObject):
raise TypeError(
"Given domain contains something that is not a "
"nifty.space.")
"DomainObject instance.")
return domain
def _parse_field_type(self, field_type, val=None):
if field_type is None:
if isinstance(val, Field):
field_type = val.field_type
else:
field_type = ()
elif isinstance(field_type, FieldType):
field_type = (field_type,)
elif not isinstance(field_type, tuple):
field_type = tuple(field_type)
for ft in field_type:
if not isinstance(ft, FieldType):
raise TypeError(
"Given object is not a nifty.FieldType.")
return field_type
def _get_axes_tuple(self, things_with_shape, start=0):
i = start
axes_list = []
......@@ -92,7 +65,7 @@ class Field(Loggable, Versionable, object):
axes_list += [tuple(l)]
return tuple(axes_list)
def _infer_dtype(self, dtype, val, domain, field_type):
def _infer_dtype(self, dtype, val, domain):
if dtype is None:
if isinstance(val, Field) or \
isinstance(val, distributed_data_object):
......@@ -102,8 +75,6 @@ class Field(Loggable, Versionable, object):
dtype_tuple = (np.dtype(dtype),)
if domain is not None:
dtype_tuple += tuple(np.dtype(sp.dtype) for sp in domain)
if field_type is not None:
dtype_tuple += tuple(np.dtype(ft.dtype) for ft in field_type)
dtype = reduce(lambda x, y: np.result_type(x, y), dtype_tuple)
......@@ -127,10 +98,10 @@ class Field(Loggable, Versionable, object):
# ---Factory methods---
@classmethod
def from_random(cls, random_type, domain=None, dtype=None, field_type=None,
def from_random(cls, random_type, domain=None, dtype=None,
distribution_strategy=None, **kwargs):
# create a initially empty field
f = cls(domain=domain, dtype=dtype, field_type=field_type,
f = cls(domain=domain, dtype=dtype,
distribution_strategy=distribution_strategy)
# now use the processed input in terms of f in order to parse the
......@@ -259,7 +230,9 @@ class Field(Loggable, Versionable, object):
result_domain = list(self.domain)
result_domain[space_index] = power_domain
result_field = self.copy_empty(domain=result_domain)
result_field = self.copy_empty(
domain=result_domain,
distribution_strategy=power_spectrum.distribution_strategy)
result_field.set_val(new_val=power_spectrum, copy=False)
return result_field
......@@ -361,7 +334,6 @@ class Field(Loggable, Versionable, object):
std=std,
domain=result_domain,
dtype=harmonic_domain.dtype,
field_type=self.field_type,
distribution_strategy=self.distribution_strategy)
for x in result_list]
......@@ -449,9 +421,7 @@ class Field(Loggable, Versionable, object):
@property
def shape(self):
shape_tuple = ()
shape_tuple += tuple(sp.shape for sp in self.domain)
shape_tuple += tuple(ft.shape for ft in self.field_type)
shape_tuple = tuple(sp.shape for sp in self.domain)
try:
global_shape = reduce(lambda x, y: x + y, shape_tuple)
except TypeError:
......@@ -461,9 +431,7 @@ class Field(Loggable, Versionable, object):
@property
def dim(self):
dim_tuple = ()
dim_tuple += tuple(sp.dim for sp in self.domain)
dim_tuple += tuple(ft.dim for ft in self.field_type)
dim_tuple = tuple(sp.dim for sp in self.domain)
try:
return reduce(lambda x, y: x * y, dim_tuple)
except TypeError:
......@@ -498,20 +466,12 @@ class Field(Loggable, Versionable, object):
casted_x = sp.pre_cast(casted_x,
axes=self.domain_axes[ind])
for ind, ft in enumerate(self.field_type):
casted_x = ft.pre_cast(casted_x,
axes=self.field_type_axes[ind])
casted_x = self._actual_cast(casted_x, dtype=dtype)
for ind, sp in enumerate(self.domain):
casted_x = sp.post_cast(casted_x,
axes=self.domain_axes[ind])
for ind, ft in enumerate(self.field_type):
casted_x = ft.post_cast(casted_x,
axes=self.field_type_axes[ind])
return casted_x
def _actual_cast(self, x, dtype=None):
......@@ -528,19 +488,16 @@ class Field(Loggable, Versionable, object):
return_x.set_full_data(x, copy=False)
return return_x
def copy(self, domain=None, dtype=None, field_type=None,
distribution_strategy=None):
def copy(self, domain=None, dtype=None, distribution_strategy=None):
copied_val = self.get_val(copy=True)
new_field = self.copy_empty(
domain=domain,
dtype=dtype,
field_type=field_type,
distribution_strategy=distribution_strategy)
new_field.set_val(new_val=copied_val, copy=False)
return new_field
def copy_empty(self, domain=None, dtype=None, field_type=None,
distribution_strategy=None):
def copy_empty(self, domain=None, dtype=None, distribution_strategy=None):
if domain is None:
domain = self.domain
else:
......@@ -551,11 +508,6 @@ class Field(Loggable, Versionable, object):
else:
dtype = np.dtype(dtype)
if field_type is None:
field_type = self.field_type
else:
field_type = self._parse_field_type(field_type)
if distribution_strategy is None:
distribution_strategy = self.distribution_strategy
......@@ -565,10 +517,6 @@ class Field(Loggable, Versionable, object):
if self.domain[i] is not domain[i]:
fast_copyable = False
break
for i in xrange(len(self.field_type)):
if self.field_type[i] is not field_type[i]:
fast_copyable = False
break
except IndexError:
fast_copyable = False
......@@ -578,7 +526,6 @@ class Field(Loggable, Versionable, object):
else:
new_field = Field(domain=domain,
dtype=dtype,
field_type=field_type,
distribution_strategy=distribution_strategy)
return new_field
......@@ -624,8 +571,6 @@ class Field(Loggable, Versionable, object):
assert len(x.domain) == len(self.domain)
for index in xrange(len(self.domain)):
assert x.domain[index] == self.domain[index]
for index in xrange(len(self.field_type)):
assert x.field_type[index] == self.field_type[index]
except AssertionError:
raise ValueError(
"domains are incompatible.")
......@@ -705,22 +650,15 @@ class Field(Loggable, Versionable, object):
return_field.set_val(new_val, copy=False)
return return_field
def _contraction_helper(self, op, spaces, types):
def _contraction_helper(self, op, spaces):
# build a list of all axes
if spaces is None:
spaces = xrange(len(self.domain))
else:
spaces = utilities.cast_axis_to_tuple(spaces, len(self.domain))
if types is None:
types = xrange(len(self.field_type))
else:
types = utilities.cast_axis_to_tuple(types, len(self.field_type))
axes_list = tuple(self.domain_axes[sp_index] for sp_index in spaces)
axes_list = ()
axes_list += tuple(self.domain_axes[sp_index] for sp_index in spaces)
axes_list += tuple(self.field_type_axes[ft_index] for
ft_index in types)
try:
axes_list = reduce(lambda x, y: x+y, axes_list)
except TypeError:
......@@ -737,47 +675,44 @@ class Field(Loggable, Versionable, object):
return_domain = tuple(self.domain[i]
for i in xrange(len(self.domain))
if i not in spaces)
return_field_type = tuple(self.field_type[i]
for i in xrange(len(self.field_type))
if i not in types)
return_field = Field(domain=return_domain,
val=data,
field_type=return_field_type,
copy=False)
return return_field
def sum(self, spaces=None, types=None):
return self._contraction_helper('sum', spaces, types)
def sum(self, spaces=None):
return self._contraction_helper('sum', spaces)
def prod(self, spaces=None, types=None):
return self._contraction_helper('prod', spaces, types)
def prod(self, spaces=None):
return self._contraction_helper('prod', spaces)
def all(self, spaces=None, types=None):
return self._contraction_helper('all', spaces, types)
def all(self, spaces=None):
return self._contraction_helper('all', spaces)
def any(self, spaces=None, types=None):
return self._contraction_helper('any', spaces, types)
def any(self, spaces=None):
return self._contraction_helper('any', spaces)
def min(self, spaces=None, types=None):
return self._contraction_helper('min', spaces, types)
def min(self, spaces=None):
return self._contraction_helper('min', spaces)
def nanmin(self, spaces=None, types=None):
return self._contraction_helper('nanmin', spaces, types)
def nanmin(self, spaces=None):
return self._contraction_helper('nanmin', spaces)
def max(self, spaces=None, types=None):
return self._contraction_helper('max', spaces, types)
def max(self, spaces=None):
return self._contraction_helper('max', spaces)
def nanmax(self, spaces=None, types=None):
return self._contraction_helper('nanmax', spaces, types)
def nanmax(self, spaces=None):
return self._contraction_helper('nanmax', spaces)
def mean(self, spaces=None, types=None):
return self._contraction_helper('mean', spaces, types)
def mean(self, spaces=None):
return self._contraction_helper('mean', spaces)
def var(self, spaces=None, types=None):
return self._contraction_helper('var', spaces, types)
def var(self, spaces=None):
return self._contraction_helper('var', spaces)
def std(self, spaces=None, types=None):
return self._contraction_helper('std', spaces, types)
def std(self, spaces=None):
return self._contraction_helper('std', spaces)
# ---General binary methods---
......@@ -788,9 +723,6 @@ class Field(Loggable, Versionable, object):
assert len(other.domain) == len(self.domain)
for index in xrange(len(self.domain)):
assert other.domain[index] == self.domain[index]
assert len(other.field_type) == len(self.field_type)
for index in xrange(len(self.field_type)):
assert other.field_type[index] == self.field_type[index]
except AssertionError:
raise ValueError(
"domains are incompatible.")
......@@ -893,19 +825,14 @@ class Field(Loggable, Versionable, object):
def _to_hdf5(self, hdf5_group):
hdf5_group.attrs['dtype'] = self.dtype.name
hdf5_group.attrs['distribution_strategy'] = self.distribution_strategy
hdf5_group.attrs['field_type_axes'] = str(self.field_type_axes)
hdf5_group.attrs['domain_axes'] = str(self.domain_axes)
hdf5_group['num_domain'] = len(self.domain)
hdf5_group['num_ft'] = len(self.field_type)
ret_dict = {'val': self.val}
for i in range(len(self.domain)):
ret_dict['s_' + str(i)] = self.domain[i]
for i in range(len(self.field_type)):
ret_dict['ft_' + str(i)] = self.field_type[i]
return ret_dict
@classmethod
......@@ -920,14 +847,7 @@ class Field(Loggable, Versionable, object):
temp_domain.append(repository.get('s_' + str(i), hdf5_group))
new_field.domain = tuple(temp_domain)
temp_ft = []
for i in range(hdf5_group['num_ft'][()]):
temp_domain.append(repository.get('ft_' + str(i), hdf5_group))
new_field.field_type = tuple(temp_ft)
exec('new_field.domain_axes = ' + hdf5_group.attrs['domain_axes'])
exec('new_field.field_type_axes = ' +
hdf5_group.attrs['field_type_axes'])
new_field._val = repository.get('val', hdf5_group)
new_field.dtype = np.dtype(hdf5_group.attrs['dtype'])
new_field.distribution_strategy =\
......
# -*- coding: utf-8 -*-
import pickle
from field_type import FieldType
class FieldArray(FieldType):
def __init__(self, dtype, shape):
try:
new_shape = tuple([int(i) for i in shape])
except TypeError:
new_shape = (int(shape), )
self._shape = new_shape
super(FieldArray, self).__init__(dtype=dtype)
@property
def shape(self):
return self._shape
@property
def dim(self):
return reduce(lambda x, y: x*y, self.shape)
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group['shape'] = self.shape
hdf5_group['dtype'] = pickle.dumps(self.dtype)
return None
@classmethod
def _from_hdf5(cls, hdf5_group, loopback_get):
result = cls(
hdf5_group['shape'][:],
pickle.loads(hdf5_group['dtype'][()])
)
return result
# -*- coding: utf-8 -*-
import pickle
import numpy as np
from keepers import Versionable
from nifty.domain_object import DomainObject
class FieldType(Versionable, object):
def __init__(self, shape, dtype):
try:
new_shape = tuple([int(i) for i in shape])
except TypeError:
new_shape = (int(shape), )
self._shape = new_shape
self._dtype = np.dtype(dtype)
def __hash__(self):
# Extract the identifying parts from the vars(self) dict.
result_hash = 0
for (key, item) in vars(self).items():
result_hash ^= item.__hash__() ^ int(hash(key)/117)
return result_hash
def __eq__(self, x):
if isinstance(x, type(self)):
return hash(self) == hash(x)
else:
return False
@property
def shape(self):
return self._shape
@property
def dtype(self):
return self._dtype
@property
def dim(self):
raise NotImplementedError
class FieldType(DomainObject):
def process(self, method_name, array, inplace=True, **kwargs):
try:
......@@ -52,25 +17,3 @@ class FieldType(Versionable, object):
result_array = array.copy()
return result_array
def pre_cast(self, x, axes=None):
return x
def post_cast(self, x, axes=None):
return x
# ---Serialization---
def _to_hdf5(self, hdf5_group):
hdf5_group['shape'] = self.shape
hdf5_group['dtype'] = pickle.dumps(self.dtype)
return None
@classmethod
def _from_hdf5(cls, hdf5_group, loopback_get):
result = cls(
hdf5_group['shape'][:],
pickle.loads(hdf5_group['dtype'][()])
)
return result
......@@ -31,6 +31,7 @@ class LineSearch(Loggable, object):
self.pk = None
self.line_energy = None
self.f_k_minus_1 = None
self.prefered_initial_step_size = None