Commit 71fb7400 authored by csongor's avatar csongor
Browse files

WIP: unary operations fixes

parent f78cf8d0
Pipeline #4574 skipped
......@@ -359,7 +359,7 @@ class field(object):
new_val = map(
lambda z: self.unary_operation(z, 'copy'),
new_val)
self.val = map(lambda z: self.cast(z), new_val)
self.val = self.cast(new_val)
return self.val
def get_val(self):
......@@ -996,11 +996,8 @@ class field(object):
"\n- codomain = " + repr(self.codomain) + \
"\n- ishape = " + str(self.ishape)
def _unary_helper(self, x, op, **kwargs):
result = map(
lambda z: self.domain.unary_operation(z, op=op, **kwargs),
self.get_val())
return result
def all(self, **kwargs):
return self._unary_operation(self.get_val(), op='all', **kwargs)
def min(self, ignore=False, **kwargs):
"""
......@@ -1021,10 +1018,10 @@ class field(object):
np.amin, np.nanmin
"""
return self._unary_helper(self.get_val(), op='amin', **kwargs)
return self._unary_operation(self.get_val(), op='amin', **kwargs)
def nanmin(self, **kwargs):
return self._unary_helper(self.get_val(), op='nanmin', **kwargs)
return self._unary_operation(self.get_val(), op='nanmin', **kwargs)
def max(self, **kwargs):
"""
......@@ -1045,10 +1042,10 @@ class field(object):
np.amax, np.nanmax
"""
return self._unary_helper(self.get_val(), op='amax', **kwargs)
return self._unary_operation(self.get_val(), op='amax', **kwargs)
def nanmax(self, **kwargs):
return self._unary_helper(self.get_val(), op='nanmax', **kwargs)
return self._unary_operation(self.get_val(), op='nanmax', **kwargs)
def median(self, **kwargs):
"""
......@@ -1064,7 +1061,7 @@ class field(object):
np.median
"""
return self._unary_helper(self.get_val(), op='median',
return self._unary_operation(self.get_val(), op='median',
**kwargs)
def mean(self, **kwargs):
......@@ -1081,7 +1078,7 @@ class field(object):
np.mean
"""
return self._unary_helper(self.get_val(), op='mean',
return self._unary_operation(self.get_val(), op='mean',
**kwargs)
def std(self, **kwargs):
......@@ -1098,7 +1095,7 @@ class field(object):
np.std
"""
return self._unary_helper(self.get_val(), op='std',
return self._unary_operation(self.get_val(), op='std',
**kwargs)
def var(self, **kwargs):
......@@ -1115,7 +1112,7 @@ class field(object):
np.var
"""
return self._unary_helper(self.get_val(), op='var',
return self._unary_operation(self.get_val(), op='var',
**kwargs)
def argmin(self, split=True, **kwargs):
......@@ -1141,10 +1138,10 @@ class field(object):
"""
if split:
return self._unary_helper(self.get_val(), op='argmin_nonflat',
return self._unary_operation(self.get_val(), op='argmin_nonflat',
**kwargs)
else:
return self._unary_helper(self.get_val(), op='argmin',
return self._unary_operation(self.get_val(), op='argmin',
**kwargs)
def argmax(self, split=True, **kwargs):
......@@ -1170,29 +1167,29 @@ class field(object):
"""
if split:
return self._unary_helper(self.get_val(), op='argmax_nonflat',
return self._unary_operation(self.get_val(), op='argmax_nonflat',
**kwargs)
else:
return self._unary_helper(self.get_val(), op='argmax',
return self._unary_operation(self.get_val(), op='argmax',
**kwargs)
# TODO: Implement the full range of unary and binary operotions
def __pos__(self):
new_field = self.copy_empty()
new_val = self._unary_helper(self.get_val(), op='pos')
new_val = self._unary_operation(self.get_val(), op='pos')
new_field.set_val(new_val=new_val)
return new_field
def __neg__(self):
new_field = self.copy_empty()
new_val = self._unary_helper(self.get_val(), op='neg')
new_val = self._unary_operation(self.get_val(), op='neg')
new_field.set_val(new_val=new_val)
return new_field
def __abs__(self):
new_field = self.copy_empty()
new_val = self._unary_helper(self.get_val(), op='abs')
new_val = self._unary_operation(self.get_val(), op='abs')
new_field.set_val(new_val=new_val)
return new_field
......@@ -1224,7 +1221,7 @@ class field(object):
working_field.set_val(new_val=new_val)
return working_field
def unary_operation(self, x, op='None', axis=None, **kwargs):
def _unary_operation(self, x, op='None', axis=None, **kwargs):
"""
x must be a numpy array which is compatible with the space!
Valid operations are
......
......@@ -9,6 +9,8 @@ import unittest
import itertools
import numpy as np
from d2o import distributed_data_object
from nifty import space, \
point_space, \
rg_space, \
......@@ -101,6 +103,21 @@ for param in itertools.product([(1,), (4, 6), (5, 8)],
fft_module=param[6]), param[5]]]
def generate_space_with_size(name, num):
space_dict = {'space': space(),
'point_space': point_space(num),
'rg_space': rg_space((num, num)),
'lm_space': lm_space(mmax=num, lmax=num),
'hp_space': hp_space(num),
'gl_space': gl_space(nlat=num, nlon=num),
}
return space_dict[name]
def generate_data(space):
a = np.arange(space.get_dim()).reshape(space.get_shape())
return distributed_data_object(a)
###############################################################################
###############################################################################
......@@ -151,3 +168,21 @@ class Test_field_multiple_init(unittest.TestCase):
assert (s1.check_codomain(f.codomain[0]))
assert (s2.check_codomain(f.codomain[1]))
assert (s1.get_shape() + s2.get_shape() == f.get_shape())
class Test_axis(unittest.TestCase):
@parameterized.expand(
itertools.product(point_like_spaces, [4],
['sum', 'prod', 'mean', 'var', 'std', 'median', 'all',
'any', 'amin', 'nanmin', 'argmin', 'amax', 'nanmax',
'argmax'],
[None, (0,)],
DATAMODELS['rg_space']),
testcase_func_name=custom_name_func)
def test_unary_operations(self, name, num, op, axis, datamodel):
s = generate_space_with_size(name, num)
d = generate_data(s)
a = d.get_full_data()
f = field(val=d, domain=(s,), dtype=s.dtype, datamodel=datamodel)
assert_almost_equal(getattr(f, op)(axis=axis),
getattr(np, op)(a, axis=axis), decimal=4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment