Commit 1c76affd authored by csongor's avatar csongor
Browse files

unary operations fixes

parent 71fb7400
......@@ -20,8 +20,7 @@
## along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import division
from nifty.keepers import about,\
global_dependency_injector as gdi
from nifty.config import about
from distutils.version import LooseVersion as lv
......
......@@ -996,9 +996,18 @@ class field(object):
"\n- codomain = " + repr(self.codomain) + \
"\n- ishape = " + str(self.ishape)
def sum(self, **kwargs):
return self._unary_operation(self.get_val(), op='sum', **kwargs)
def prod(self, **kwargs):
return self._unary_operation(self.get_val(), op='prod', **kwargs)
def all(self, **kwargs):
return self._unary_operation(self.get_val(), op='all', **kwargs)
def any(self, **kwargs):
return self._unary_operation(self.get_val(), op='any', **kwargs)
def min(self, ignore=False, **kwargs):
"""
Returns the minimum of the field values.
......
......@@ -619,8 +619,8 @@ class rg_power_indices(power_indices):
class lm_power_indices(power_indices):
def __init__(self, lmax, dim, datamodel,
allowed_distribution_strategies,
def __init__(self, lmax, dim,
allowed_distribution_strategies, datamodel='not',
zerocentered=False, log=False, nbin=None,
binbounds=None, comm=None):
"""
......
......@@ -107,7 +107,7 @@ def generate_space_with_size(name, num):
space_dict = {'space': space(),
'point_space': point_space(num),
'rg_space': rg_space((num, num)),
'lm_space': lm_space(mmax=num, lmax=num),
'lm_space': lm_space(mmax=num+1, lmax=num+1),
'hp_space': hp_space(num),
'gl_space': gl_space(nlat=num, nlon=num),
}
......@@ -143,6 +143,19 @@ class Test_field_init(unittest.TestCase):
assert (s.get_shape() == f.get_shape())
class Test_field_init2(unittest.TestCase):
@parameterized.expand(
itertools.product(point_like_spaces, [4],
DATAMODELS['rg_space']),
testcase_func_name=custom_name_func)
def test_successfull_init_and_attributes(self, name, num, datamodel):
s = generate_space_with_size(name, num)
d = generate_data(s)
f = field(val=d, domain=(s,), dtype=s.dtype, datamodel=datamodel)
assert (f.domain[0] is s)
assert (s.check_codomain(f.codomain[0]))
assert (s.get_shape() == f.get_shape())
class Test_field_multiple_init(unittest.TestCase):
@parameterized.expand(
itertools.product([(1,)],
......@@ -174,7 +187,7 @@ class Test_axis(unittest.TestCase):
@parameterized.expand(
itertools.product(point_like_spaces, [4],
['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'amin', 'nanmin', 'argmin', 'amax', 'nanmax',
'any', 'min', 'nanmin', 'argmin', 'max', 'nanmax',
'argmax'],
[None, (0,)],
DATAMODELS['rg_space']),
......@@ -184,5 +197,9 @@ class Test_axis(unittest.TestCase):
d = generate_data(s)
a = d.get_full_data()
f = field(val=d, domain=(s,), dtype=s.dtype, datamodel=datamodel)
assert_almost_equal(getattr(f, op)(axis=axis),
if op in ['argmin','argmax']:
assert_almost_equal(getattr(f, op)(),
getattr(np, op)(a), decimal=4)
else:
assert_almost_equal(getattr(f, op)(axis=axis),
getattr(np, op)(a, axis=axis), decimal=4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment