Commit f399b0da authored by csongor's avatar csongor
Browse files

Add axis keyword to point_space and test it

parent 51428d74
Pipeline #1794 skipped
......@@ -892,7 +892,7 @@ class point_space(space):
def apply_scalar_function(self, x, function, inplace=False):
return x.apply_scalar_function(function, inplace=inplace)
def unary_operation(self, x, op='None', **kwargs):
def unary_operation(self, x, op='None', axis=None, **kwargs):
"""
x must be a numpy array which is compatible with the space!
Valid operations are
......@@ -903,21 +903,21 @@ class point_space(space):
'abs': lambda y: getattr(y, '__abs__')(),
'real': lambda y: getattr(y, 'real'),
'imag': lambda y: getattr(y, 'imag'),
'nanmin': lambda y: getattr(y, 'nanmin')(),
'amin': lambda y: getattr(y, 'amin')(),
'nanmax': lambda y: getattr(y, 'nanmax')(),
'amax': lambda y: getattr(y, 'amax')(),
'median': lambda y: getattr(y, 'median')(),
'mean': lambda y: getattr(y, 'mean')(),
'std': lambda y: getattr(y, 'std')(),
'var': lambda y: getattr(y, 'var')(),
'nanmin': lambda y: getattr(y, 'nanmin')(axis=axis),
'amin': lambda y: getattr(y, 'amin')(axis=axis),
'nanmax': lambda y: getattr(y, 'nanmax')(axis=axis),
'amax': lambda y: getattr(y, 'amax')(axis=axis),
'median': lambda y: getattr(y, 'median')(axis=axis),
'mean': lambda y: getattr(y, 'mean')(axis=axis),
'std': lambda y: getattr(y, 'std')(axis=axis),
'var': lambda y: getattr(y, 'var')(axis=axis),
'argmin': lambda y: getattr(y, 'argmin_nonflat')(),
'argmin_flat': lambda y: getattr(y, 'argmin')(),
'argmax': lambda y: getattr(y, 'argmax_nonflat')(),
'argmax_flat': lambda y: getattr(y, 'argmax')(),
'conjugate': lambda y: getattr(y, 'conjugate')(),
'sum': lambda y: getattr(y, 'sum')(),
'prod': lambda y: getattr(y, 'prod')(),
'sum': lambda y: getattr(y, 'sum')(axis=axis),
'prod': lambda y: getattr(y, 'prod')(axis=axis),
'unique': lambda y: getattr(y, 'unique')(),
'copy': lambda y: getattr(y, 'copy')(),
'copy_empty': lambda y: getattr(y, 'copy_empty')(),
......@@ -925,8 +925,8 @@ class point_space(space):
'isinf': lambda y: getattr(y, 'isinf')(),
'isfinite': lambda y: getattr(y, 'isfinite')(),
'nan_to_num': lambda y: getattr(y, 'nan_to_num')(),
'all': lambda y: getattr(y, 'all')(),
'any': lambda y: getattr(y, 'any')(),
'all': lambda y: getattr(y, 'all')(axis=axis),
'any': lambda y: getattr(y, 'any')(axis=axis),
'None': lambda y: y}
return translation[op](x, **kwargs)
......
......@@ -27,6 +27,8 @@ from nifty.nifty_power_indices import power_indices
from nifty.nifty_utilities import _hermitianize_inverter as \
hermitianize_inverter
from nifty.operators.nifty_operators import power_operator
available = []
try:
from nifty import lm_space
......@@ -178,6 +180,22 @@ def generate_space(name):
return space_dict[name]
def generate_space_with_size(name, num, datamodel='fftw'):
space_dict = {'space': space(),
'point_space': point_space(num, datamodel=datamodel),
'rg_space': rg_space((num, num), datamodel=datamodel),
}
if 'lm_space' in available:
space_dict['lm_space'] = lm_space(mmax=num, lmax=num,
datamodel=datamodel)
if 'hp_space' in available:
space_dict['hp_space'] = hp_space(num, datamodel=datamodel)
if 'gl_space' in available:
space_dict['gl_space'] = gl_space(nlat=num, nlon=num,
datamodel=datamodel)
return space_dict[name]
def generate_data(space):
a = np.arange(space.get_dim()).reshape(space.get_shape())
data = space.cast(a)
......@@ -1334,4 +1352,21 @@ class Test_Lm_Space(unittest.TestCase):
print all_spaces
print generate_space('rg_space')
\ No newline at end of file
print generate_space('rg_space')
class Test_axis(unittest.TestCase):
@parameterized.expand(
itertools.product(point_like_spaces, [8, 16],
['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'nanmin', 'nanmax'], [None, (0,)],
DATAMODELS['point_space']),
testcase_func_name=custom_name_func)
def test_binary_operations(self, name, num, op, axis, datamodel):
s = generate_space_with_size(name, np.prod(num), datamodel=datamodel)
d = generate_data(s)
a = d.get_full_data()
assert_almost_equal(s.unary_operation(d, op, axis=axis),
getattr(np, op)(a, axis=axis), decimal=4)
if name in ['rg_space']:
assert_almost_equal(s.unary_operation(d, op, axis=(0, 1)),
getattr(np, op)(a, axis=(0, 1)), decimal=4)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment