diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
new file mode 100644
index 0000000000000000000000000000000000000000..83c2166e819f3e8ad0ca3713734ae386d0abe1ca
--- /dev/null
+++ b/.gitlab-ci.yml
@@ -0,0 +1,36 @@
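+# Each job tests d2o against a progressively larger stack of optional dependencies.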
+image: python:2.7-wheezy
+before_script:
+ - pip install nose
+ - pip install nose_parameterized
+ - pip install numpy
+ - chmod +x continuous_integration/*
+
+test_minimal:
+ script:
+ - python setup.py install
+ - nosetests
+
+test_mpi:
+ script:
+ - continuous_integration/install_mpi.sh
+ - python setup.py install
+ - nosetests
+ - mpirun -n 2 nosetests
+ - mpirun -n 5 nosetests
+
+test_mpi_fftw:
+ script:
+ - continuous_integration/install_mpi.sh
+ - continuous_integration/install_fftw.sh
+ - python setup.py install
+ - mpirun -n 2 nosetests
+
+test_mpi_fftw_hdf5:
+ script:
+ - continuous_integration/install_mpi.sh
+ - continuous_integration/install_fftw.sh
+ - continuous_integration/install_h5py.sh
+ - python setup.py install
+ - mpirun -n 2 nosetests
+
diff --git a/continuous_integration/install_fftw.sh b/continuous_integration/install_fftw.sh
new file mode 100644
index 0000000000000000000000000000000000000000..8ca10b0665832dcf4346240ac4d9e7d3ef7ede9f
--- /dev/null
+++ b/continuous_integration/install_fftw.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+pip install cython
+
+apt-get install -y libfftw3-3 libfftw3-bin libfftw3-dev libfftw3-mpi-dev libfftw3-mpi3
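+# build the MPI-enabled pyFFTW fork (branch 'mpi'); stock pyFFTW does not expose FFTW_MPI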
+git clone -b mpi https://github.com/fredRos/pyFFTW.git
+cd pyFFTW/
+#export LDFLAGS="-L/usr/include"
+#export CFLAGS="-I/usr/include"
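+# compile with the MPI compiler wrapper so the extension links against the MPI FFTW libraries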
+CC=mpicc python setup.py build_ext install
+cd ..
diff --git a/continuous_integration/install_h5py.sh b/continuous_integration/install_h5py.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e6529bbc14f7a25cadb9fdadcc2d5f567162c531
--- /dev/null
+++ b/continuous_integration/install_h5py.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+pip install cython
+
+apt-get update -qy
+apt-get install -y libhdf5-openmpi-dev
+
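+# fetch the tarball URL of the latest h5py release from the GitHub API and download it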
+curl -s https://api.github.com/repos/h5py/h5py/tags | grep tarball_url | head -n 1 | cut -d '"' -f 4 | wget -i - -O h5py.tar.gz
+tar xzf h5py.tar.gz
+cd h5py-h5py*
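+# build h5py against the MPI compiler wrapper so it links the parallel HDF5 library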
+export CC=mpicc
+python setup.py configure --mpi
+python setup.py build
+python setup.py install
+cd ..
diff --git a/continuous_integration/install_mpi.sh b/continuous_integration/install_mpi.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e468a744867b7e56d243f6cf9235eb22991ce07a
--- /dev/null
+++ b/continuous_integration/install_mpi.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
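+# install the OpenMPI runtime and development headers, then the mpi4py bindings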
+apt-get update -qy
+apt-get install -y openmpi-bin libopenmpi-dev
+pip install mpi4py
diff --git a/d2o/config/d2o_config.py b/d2o/config/d2o_config.py
index dfd4f968a4c12556f1376ffe6d38bf3f2b1a7e9c..1059fed25b08dc9e1ba50d7de391a7189c4c74de 100644
--- a/d2o/config/d2o_config.py
+++ b/d2o/config/d2o_config.py
@@ -22,7 +22,7 @@ import keepers
dependency_injector = keepers.DependencyInjector(
['h5py',
('mpi4py.MPI', 'MPI'),
- ('d2o.mpi_dummy.mpi_dummy', 'MPI_dummy')]
+ ('mpi_dummy', 'MPI_dummy')]
)
dependency_injector.register('pyfftw', lambda z: hasattr(z, 'FFTW_MPI'))
diff --git a/d2o/distributed_data_object.py b/d2o/distributed_data_object.py
index 06e6ce844042604017f2fb733dc763665fb7f6be..fa17585c8709fa5ac5dc0c21b3ac264856fc4805 100644
--- a/d2o/distributed_data_object.py
+++ b/d2o/distributed_data_object.py
@@ -1282,6 +1282,7 @@ class distributed_data_object(Versionable, object):
('index', np.dtype('float'))])
local_argmax_list = np.sort(local_argmax_list,
order=['value', 'index'])
+
# take the last entry here and correct the minus sign of the index
return -np.int(local_argmax_list[-1][1])
diff --git a/d2o/distributor_factory.py b/d2o/distributor_factory.py
index a413add232b86f1f71f6691e442483314497f3f0..baf106e2f55283b29cc5b061452de8eb387ae505 100644
--- a/d2o/distributor_factory.py
+++ b/d2o/distributor_factory.py
@@ -1916,86 +1916,72 @@ class _slicing_distributor(distributor):
else:
return local_data.reshape(temp_local_shape)
- if 'h5py' in gdi:
- def save_data(self, data, alias, path=None, overwriteQ=True):
- comm = self.comm
- h5py_parallel = h5py.get_config().mpi
- if comm.size > 1 and not h5py_parallel:
- raise RuntimeError("ERROR: Programm is run with MPI " +
- "size > 1 but non-parallel version of " +
- "h5py is loaded.")
- # if no path and therefore no filename was given, use the alias
- # as filename
- use_path = alias if path is None else path
-
- # create the file-handle
- if h5py_parallel and gc['mpi_module'] == 'MPI':
- f = h5py.File(use_path, 'a', driver='mpio', comm=comm)
- else:
- f = h5py.File(use_path, 'a')
- # check if dataset with name == alias already exists
- try:
- f[alias]
- # if yes, and overwriteQ is set to False, raise an Error
- if overwriteQ is False:
- raise ValueError(about_cstring(
- "ERROR: overwriteQ is False, but alias already " +
- "in use!"))
- else: # if yes, remove the existing dataset
- del f[alias]
- except(KeyError):
- pass
-
- # create dataset
- dset = f.create_dataset(alias,
- shape=self.global_shape,
- dtype=self.dtype)
- # write the data
- dset[self.local_start:self.local_end] = data
- # close the file
- f.close()
-
- def load_data(self, alias, path):
- comm = self.comm
- # parse the path
- file_path = path if (path is not None) else alias
- # create the file-handle
- if h5py.get_config().mpi and gc['mpi_module'] == 'MPI':
- f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
- else:
- f = h5py.File(file_path, 'r')
- dset = f[alias]
- # check shape
- if dset.shape != self.global_shape:
- raise TypeError(about_cstring(
- "ERROR: The shape of the given dataset does not match " +
- "the distributed_data_object."))
- # check dtype
- if dset.dtype != self.dtype:
- raise TypeError(about_cstring(
- "ERROR: The datatype of the given dataset does not " +
- "match the one of the distributed_data_object."))
- # if everything seems to fit, load the data
- data = dset[self.local_start:self.local_end]
- # close the file
- f.close()
- return data
+ def save_data(self, data, alias, path=None, overwriteQ=True):
+ comm = self.comm
+ h5py_parallel = h5py.get_config().mpi
+ if comm.size > 1 and not h5py_parallel:
+ raise RuntimeError("ERROR: Programm is run with MPI " +
+ "size > 1 but non-parallel version of " +
+ "h5py is loaded.")
+ # if no path and therefore no filename was given, use the alias
+ # as filename
+ use_path = alias if path is None else path
+
+ # create the file-handle
+ if h5py_parallel and gc['mpi_module'] == 'MPI':
+ f = h5py.File(use_path, 'a', driver='mpio', comm=comm)
+ else:
+ f = h5py.File(use_path, 'a')
+ # check if dataset with name == alias already exists
+ try:
+ f[alias]
+ # if yes, and overwriteQ is set to False, raise an Error
+ if overwriteQ is False:
+ raise ValueError(about_cstring(
+ "ERROR: overwriteQ is False, but alias already " +
+ "in use!"))
+ else: # if yes, remove the existing dataset
+ del f[alias]
+ except KeyError:
+ pass
- def _data_to_hdf5(self, hdf5_dataset, data):
- hdf5_dataset[self.local_start:self.local_end] = data
+ # create dataset
+ dset = f.create_dataset(alias,
+ shape=self.global_shape,
+ dtype=self.dtype)
+ # write the data
+ dset[self.local_start:self.local_end] = data
+ # close the file
+ f.close()
- else:
- def save_data(self, *args, **kwargs):
- raise ImportError(about_cstring(
- "ERROR: h5py is not available"))
-
- def load_data(self, *args, **kwargs):
- raise ImportError(about_cstring(
- "ERROR: h5py is not available"))
+ def load_data(self, alias, path):
+ comm = self.comm
+ # parse the path
+ file_path = path if (path is not None) else alias
+ # create the file-handle
+ if h5py.get_config().mpi and gc['mpi_module'] == 'MPI':
+ f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
+ else:
+ f = h5py.File(file_path, 'r')
+ dset = f[alias]
+ # check shape
+ if dset.shape != self.global_shape:
+ raise TypeError(about_cstring(
+ "ERROR: The shape of the given dataset does not match " +
+ "the distributed_data_object."))
+ # check dtype
+ if dset.dtype != self.dtype:
+ raise TypeError(about_cstring(
+ "ERROR: The datatype of the given dataset does not " +
+ "match the one of the distributed_data_object."))
+ # if everything seems to fit, load the data
+ data = dset[self.local_start:self.local_end]
+ # close the file
+ f.close()
+ return data
- def _data_to_hdf5(self, *args, **kwargs):
- raise ImportError(about_cstring(
- "ERROR: h5py is not available"))
+ def _data_to_hdf5(self, hdf5_dataset, data):
+ hdf5_dataset[self.local_start:self.local_end] = data
def get_iter(self, d2o):
return d2o_slicing_iter(d2o)
@@ -2284,88 +2270,74 @@ class _not_distributor(distributor):
a = obj.get_local_data(copy=False)
return np.searchsorted(a=a, v=v, side=side)
- if 'h5py' in gdi:
- def save_data(self, data, alias, path=None, overwriteQ=True):
- comm = self.comm
- h5py_parallel = h5py.get_config().mpi
- if comm.size > 1 and not h5py_parallel:
- raise RuntimeError("ERROR: Programm is run with MPI " +
- "size > 1 but non-parallel version of " +
- "h5py is loaded.")
- # if no path and therefore no filename was given, use the alias
- # as filename
- use_path = alias if path is None else path
-
- # create the file-handle
- if h5py_parallel and gc['mpi_module'] == 'MPI':
- f = h5py.File(use_path, 'a', driver='mpio', comm=comm)
- else:
- f = h5py.File(use_path, 'a')
- # check if dataset with name == alias already exists
- try:
- f[alias]
- # if yes, and overwriteQ is set to False, raise an Error
- if overwriteQ is False:
- raise ValueError(about_cstring(
- "ERROR: overwriteQ == False, but alias already " +
- "in use!"))
- else: # if yes, remove the existing dataset
- del f[alias]
- except(KeyError):
- pass
-
- # create dataset
- dset = f.create_dataset(alias,
- shape=self.global_shape,
- dtype=self.dtype)
- # write the data
- if comm.rank == 0:
- dset[:] = data
- # close the file
- f.close()
-
- def load_data(self, alias, path):
- comm = self.comm
- # parse the path
- file_path = path if (path is not None) else alias
- # create the file-handle
- if h5py.get_config().mpi and gc['mpi_module'] == 'MPI':
- f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
- else:
- f = h5py.File(file_path, 'r')
- dset = f[alias]
- # check shape
- if dset.shape != self.global_shape:
- raise TypeError(about_cstring(
- "ERROR: The shape of the given dataset does not match " +
- "the distributed_data_object."))
- # check dtype
- if dset.dtype != self.dtype:
- raise TypeError(about_cstring(
- "ERROR: The datatype of the given dataset does not " +
- "match the distributed_data_object."))
- # if everything seems to fit, load the data
- data = dset[:]
- # close the file
- f.close()
- return data
-
- def _data_to_hdf5(self, hdf5_dataset, data):
- if self.comm.rank == 0:
- hdf5_dataset[:] = data
-
- else:
- def save_data(self, *args, **kwargs):
- raise ImportError(about_cstring(
- "ERROR: h5py is not available"))
+ def save_data(self, data, alias, path=None, overwriteQ=True):
+ comm = self.comm
+ h5py_parallel = h5py.get_config().mpi
+ if comm.size > 1 and not h5py_parallel:
+ raise RuntimeError("ERROR: Programm is run with MPI " +
+ "size > 1 but non-parallel version of " +
+ "h5py is loaded.")
+ # if no path and therefore no filename was given, use the alias
+ # as filename
+ use_path = alias if path is None else path
+
+ # create the file-handle
+ if h5py_parallel and gc['mpi_module'] == 'MPI':
+ f = h5py.File(use_path, 'a', driver='mpio', comm=comm)
+ else:
+ f = h5py.File(use_path, 'a')
+ # check if dataset with name == alias already exists
+ try:
+ f[alias]
+ # if yes, and overwriteQ is set to False, raise an Error
+ if overwriteQ is False:
+ raise ValueError(about_cstring(
+ "ERROR: overwriteQ == False, but alias already " +
+ "in use!"))
+ else: # if yes, remove the existing dataset
+ del f[alias]
+ except KeyError:
+ pass
- def load_data(self, *args, **kwargs):
- raise ImportError(about_cstring(
- "ERROR: h5py is not available"))
+ # create dataset
+ dset = f.create_dataset(alias,
+ shape=self.global_shape,
+ dtype=self.dtype)
+ # write the data
+ if comm.rank == 0:
+ dset[:] = data
+ # close the file
+ f.close()
+
+ def load_data(self, alias, path):
+ comm = self.comm
+ # parse the path
+ file_path = path if (path is not None) else alias
+ # create the file-handle
+ if h5py.get_config().mpi and gc['mpi_module'] == 'MPI':
+ f = h5py.File(file_path, 'r', driver='mpio', comm=comm)
+ else:
+ f = h5py.File(file_path, 'r')
+ dset = f[alias]
+ # check shape
+ if dset.shape != self.global_shape:
+ raise TypeError(about_cstring(
+ "ERROR: The shape of the given dataset does not match " +
+ "the distributed_data_object."))
+ # check dtype
+ if dset.dtype != self.dtype:
+ raise TypeError(about_cstring(
+ "ERROR: The datatype of the given dataset does not " +
+ "match the distributed_data_object."))
+ # if everything seems to fit, load the data
+ data = dset[:]
+ # close the file
+ f.close()
+ return data
- def _data_to_hdf5(self, *args, **kwargs):
- raise ImportError(about_cstring(
- "ERROR: h5py is not available"))
+ def _data_to_hdf5(self, hdf5_dataset, data):
+ if self.comm.rank == 0:
+ hdf5_dataset[:] = data
def get_iter(self, d2o):
return d2o_not_iter(d2o)
diff --git a/d2o/mpi_dummy/__init__.py b/d2o/mpi_dummy/__init__.py
deleted file mode 100644
index a4ff6e793a707f7fb8f033696f3ec7db780fa5f7..0000000000000000000000000000000000000000
--- a/d2o/mpi_dummy/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# D2O
-# Copyright (C) 2016 Theo Steininger
-#
-# Author: Theo Steininger
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <http://www.gnu.org/licenses/>.
-
-import mpi_dummy
\ No newline at end of file
diff --git a/d2o/mpi_dummy/mpi_dummy.py b/d2o/mpi_dummy/mpi_dummy.py
deleted file mode 100644
index 104f07539a074edf00d47341146cc296ae6292c8..0000000000000000000000000000000000000000
--- a/d2o/mpi_dummy/mpi_dummy.py
+++ /dev/null
@@ -1,175 +0,0 @@
-# D2O
-# Copyright (C) 2016 Theo Steininger
-#
-# Author: Theo Steininger
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program. If not, see <http://www.gnu.org/licenses/>.
-
-import copy
-import numpy as np
-
-
-class Op(object):
- @classmethod
- def Create(cls, function, commute=False):
- pass
-
-MIN = Op()
-MAX = Op()
-SUM = Op()
-PROD = Op()
-LAND = Op()
-LOR = Op()
-BAND = Op()
-BOR = Op()
-
-
-class Comm(object):
- pass
-
-
-class Intracomm(Comm):
- def __init__(self, name):
- if not running_single_threadedQ():
- raise RuntimeError("ERROR: MPI_dummy module is running in a " +
- "mpirun with n>1.")
- self.name = name
- self.rank = 0
- self.size = 1
-
- def Get_rank(self):
- return self.rank
-
- def Get_size(self):
- return self.size
-
- def _scattergather_helper(self, sendbuf, recvbuf=None, **kwargs):
- sendbuf = self._unwrapper(sendbuf)
- recvbuf = self._unwrapper(recvbuf)
- if recvbuf is not None:
- recvbuf[:] = sendbuf
- return recvbuf
- else:
- recvbuf = np.copy(sendbuf)
- return recvbuf
-
- def bcast(self, sendbuf, *args, **kwargs):
- return sendbuf
-
- def Bcast(self, sendbuf, *args, **kwargs):
- return sendbuf
-
- def scatter(self, sendbuf, *args, **kwargs):
- return sendbuf[0]
-
- def Scatter(self, *args, **kwargs):
- return self._scattergather_helper(*args, **kwargs)
-
- def Scatterv(self, *args, **kwargs):
- return self._scattergather_helper(*args, **kwargs)
-
- def gather(self, sendbuf, *args, **kwargs):
- return [sendbuf]
-
- def Gather(self, *args, **kwargs):
- return self._scattergather_helper(*args, **kwargs)
-
- def Gatherv(self, *args, **kwargs):
- return self._scattergather_helper(*args, **kwargs)
-
- def allgather(self, sendbuf, *args, **kwargs):
- return [sendbuf]
-
- def Allgather(self, *args, **kwargs):
- return self._scattergather_helper(*args, **kwargs)
-
- def Allgatherv(self, *args, **kwargs):
- return self._scattergather_helper(*args, **kwargs)
-
- def Allreduce(self, sendbuf, recvbuf, op, **kwargs):
- sendbuf = self._unwrapper(sendbuf)
- recvbuf = self._unwrapper(recvbuf)
- recvbuf[:] = sendbuf
- return recvbuf
-
- def allreduce(self, sendobj, op=SUM, **kwargs):
- if np.isscalar(sendobj):
- return sendobj
- return copy.copy(sendobj)
-
- def sendrecv(self, sendobj, **kwargs):
- return sendobj
-
- def _unwrapper(self, x):
- if isinstance(x, list):
- return x[0]
- else:
- return x
-
- def Barrier(self):
- pass
-
-
-class _datatype():
- def __init__(self, name):
- self.name = str(name)
-
-
-def running_single_threadedQ():
- try:
- from mpi4py import MPI
- except ImportError:
- return True
- else:
- if MPI.COMM_WORLD.size != 1:
- return False
- else:
- return True
-
-
-BYTE = _datatype('MPI_BYTE')
-SHORT = _datatype('MPI_SHORT')
-UNSIGNED_SHORT = _datatype("MPI_UNSIGNED_SHORT")
-UNSIGNED_INT = _datatype("MPI_UNSIGNED_INT")
-INT = _datatype("MPI_INT")
-LONG = _datatype("MPI_LONG")
-UNSIGNED_LONG = _datatype("MPI_UNSIGNED_LONG")
-LONG_LONG = _datatype("MPI_LONG_LONG")
-UNSIGNED_LONG_LONG = _datatype("MPI_UNSIGNED_LONG_LONG")
-FLOAT = _datatype("MPI_FLOAT")
-DOUBLE = _datatype("MPI_DOUBLE")
-LONG_DOUBLE = _datatype("MPI_LONG_DOUBLE")
-COMPLEX = _datatype("MPI_COMPLEX")
-DOUBLE_COMPLEX = _datatype("MPI_DOUBLE_COMPLEX")
-
-
-class _comm_wrapper(Intracomm):
- def __init__(self, name):
- self.cache = None
- self.name = name
- self.size = 1
- self.rank = 0
-
- @property
- def comm(self):
- if self.cache is None:
- self.cache = Intracomm(self.name)
- return self.cache
-
- def __getattr__(self, x):
- return self.comm.__getattribute__(x)
-
-
-COMM_WORLD = _comm_wrapper('MPI_dummy_COMM_WORLD')
-#COMM_WORLD.__class__ = COMM_WORLD.comm.__class__
diff --git a/setup.py b/setup.py
index df347ef3ea5d461085f715abe92b6f73eac2e7c2..f651ecff00c7af2008285b3c4b2e14c111041fbd 100644
--- a/setup.py
+++ b/setup.py
@@ -37,11 +37,12 @@ setup(
"computing in Python"),
keywords = "parallelization, numerics, MPI",
url = "https://gitlab.mpcdf.mpg.de/ift/D2O",
- packages=['d2o', 'd2o.config', 'd2o.mpi_dummy', 'test'],
+ packages=['d2o', 'd2o.config', 'test'],
zip_safe=False,
dependency_links = [
- "git+https://gitlab.mpcdf.mpg.de/ift/keepers.git#egg=keepers-0.3.4"],
- install_requires=['keepers>=0.3.4'],
+ "git+https://gitlab.mpcdf.mpg.de/ift/keepers.git#egg=keepers-0.3.4",
+ "git+https://gitlab.mpcdf.mpg.de/ift/mpi_dummy.git#egg=mpi_dummy-1.0.0"],
+ install_requires=['keepers>=0.3.4', 'mpi_dummy>=1.0.0'],
long_description=read('README.rst'),
license = "GPLv3",
classifiers=[
diff --git a/test/test_distributed_data_object.py b/test/test_distributed_data_object.py
index 4ba08f4f72f175b388278d8b12127e2d0d6432d6..51deb1ce7c57d49f3e1261a2587a5a9a07ec50e2 100644
--- a/test/test_distributed_data_object.py
+++ b/test/test_distributed_data_object.py
@@ -18,7 +18,8 @@
from numpy.testing import assert_equal,\
assert_almost_equal,\
- assert_raises
+ assert_raises,\
+ assert_allclose
from nose_parameterized import parameterized
import unittest
@@ -136,9 +137,9 @@ def generate_data(global_shape, dtype, distribution_strategy,
local_shape[0] = 0
else:
local_shape[0] = global_shape[0] // np.ceil(size / 2.)
- number_of_extras = global_shape[
- 0] - local_shape[0] * np.ceil(size / 2.)
- if number_of_extras > rank:
+ number_of_extras = (global_shape[0] -
+ local_shape[0] * np.ceil(size / 2.))
+ if number_of_extras > rank // 2:
local_shape[0] += 1
local_shape = tuple(local_shape)
@@ -1530,8 +1531,8 @@ class Test_contractions(unittest.TestCase):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
- assert_almost_equal(getattr(obj, function)(), getattr(np, function)(a),
- decimal=4)
+ assert_allclose(getattr(obj, function)(), getattr(np, function)(a),
+ rtol=1e-4)
###############################################################################
@@ -1547,8 +1548,8 @@ class Test_contractions(unittest.TestCase):
(a, obj) = generate_data(global_shape, dtype,
distribution_strategy,
strictly_positive=True)
- assert_almost_equal(getattr(obj, function)(), getattr(np, function)(a),
- decimal=4)
+ assert_allclose(getattr(obj, function)(), getattr(np, function)(a),
+ rtol=1e-4)
###############################################################################
@@ -1557,9 +1558,10 @@
all_distribution_strategies
))
def test_argmin_argmax(self, dtype, distribution_strategy):
global_shape = (8, 8)
(a, obj) = generate_data(global_shape, dtype,
- distribution_strategy)
+ distribution_strategy,
+ strictly_positive=True)
assert_equal(obj.argmax(), np.argmax(a))
assert_equal(obj.argmin(), np.argmin(a))
assert_equal(obj.argmin_nonflat(),