Commit 5db4d734 authored by csongor's avatar csongor
Browse files

WIP: fix field.__init__ for multiple spaces

parent 7714b6cb
Pipeline #3933 skipped
......@@ -2,11 +2,11 @@ from __future__ import division
import numpy as np
import pylab as pl
from d2o import distributed_data_object,\
from d2o import distributed_data_object, \
STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.config import about,\
nifty_configuration as gc,\
from nifty.config import about, \
nifty_configuration as gc, \
dependency_injector as gdi
from nifty.nifty_core import space
......@@ -256,12 +256,12 @@ class field(object):
if len(domain) == 1:
return (domain[0].get_codomain(),)
else:
# TODO implement for multiple domain get_codomain need
# calc_transform
return np.empty((0,))
codomain = tuple(space.get_codomain() for space in domain)
self.codomain = codomain
return codomain
def __len__(self):
return int(self.get_dim(split=True)[0])
return int(self.get_dim()[0])
def apply_scalar_function(self, function, inplace=False):
if inplace:
......@@ -348,7 +348,7 @@ class field(object):
def __getitem__(self, key):
if np.isscalar(key) == True or isinstance(key, slice):
key = (key, )
key = (key,)
if self.ishape == ():
return self.domain.getitem(self.get_val(), key)
else:
......@@ -371,7 +371,7 @@ class field(object):
def __setitem__(self, key, value):
if np.isscalar(key) or isinstance(key, slice):
key = (key, )
key = (key,)
if self.ishape == ():
return self.domain.setitem(self.get_val(), value, key)
else:
......@@ -398,7 +398,7 @@ class field(object):
def get_shape(self):
if len(self.domain) > 1:
global_shape = reduce(lambda x, y: x.get_shape()+y.get_shape(),
global_shape = reduce(lambda x, y: x.get_shape() + y.get_shape(),
self.domain)
else:
global_shape = self.domain[0].get_shape()
......@@ -408,7 +408,7 @@ class field(object):
else:
return ()
def get_dim(self, split=False):
def get_dim(self):
"""
Computes the (array) dimension of the underlying space.
......@@ -430,7 +430,7 @@ class field(object):
def get_dof(self, split=False):
dim = self.get_dim()
if np.issubdtype(self.dtype, np.complex):
return 2*dim
return 2 * dim
else:
return dim
......@@ -541,7 +541,7 @@ class field(object):
return self.cast(x, dtype=dtype)
def _complement_cast(self, x):
#TODO implement complement cast for multiple spaces.
# TODO implement complement cast for multiple spaces.
return x
def set_domain(self, new_domain=None, force=False):
......@@ -559,7 +559,7 @@ class field(object):
if new_domain is None:
new_domain = self.codomain.get_codomain()
elif not force:
assert(self.codomain.check_codomain(new_domain))
assert (self.codomain.check_codomain(new_domain))
self.domain = new_domain
return self.domain
......@@ -578,7 +578,7 @@ class field(object):
if new_codomain is None:
new_codomain = self.domain.get_codomain()
elif not force:
assert(self.domain.check_codomain(new_codomain))
assert (self.domain.check_codomain(new_codomain))
self.codomain = new_codomain
return self.codomain
......@@ -630,9 +630,9 @@ class field(object):
"""
if q == 0.5:
return (self.dot(x=self))**(1 / 2)
return (self.dot(x=self)) ** (1 / 2)
else:
return self.dot(x=self**(q - 1))**(1 / q)
return self.dot(x=self ** (q - 1)) ** (1 / q)
def dot(self, x=None, axis=None, bare=False):
"""
......@@ -788,7 +788,7 @@ class field(object):
else:
new_codomain = new_domain.get_codomain()
else:
assert(new_domain.check_codomain(new_codomain))
assert (new_domain.check_codomain(new_codomain))
new_val = map(
lambda z: self.domain.calc_transform(
......@@ -882,7 +882,7 @@ class field(object):
Returns the power spectrum.
"""
if("codomain" in kwargs):
if ("codomain" in kwargs):
kwargs.__delitem__("codomain")
about.warnings.cprint("WARNING: codomain was removed from kwargs.")
......@@ -1013,7 +1013,7 @@ class field(object):
minmax = [self.min(), self.max()]
mean = self.mean()
return "nifty_core.field instance\n- domain = " + \
repr(self.domain) +\
repr(self.domain) + \
"\n- val = " + repr(self.get_val()) + \
"\n - min.,max. = " + str(minmax) + \
"\n - mean = " + str(mean) + \
......@@ -1233,13 +1233,10 @@ class field(object):
other_val = other
# bring other_val into the right shape
if self.ishape == ():
other_val = self._cast_to_scalar_helper(other_val)
else:
other_val = self._cast_to_tensor_helper(other_val)
other_val = self._cast_to_d2o(other_val)
new_val = map(
lambda z1, z2: self.domain.binary_operation(z1, z2, op=op, cast=0),
lambda z1, z2: self.binary_operation(z1, z2, op=op, cast=0),
self.get_val(),
other_val)
......@@ -1251,8 +1248,82 @@ class field(object):
working_field.set_val(new_val=new_val)
return working_field
def unary_operation(self, x, op='None', axis=None, **kwargs):
"""
x must be a numpy array which is compatible with the space!
Valid operations are
"""
translation = {'pos': lambda y: getattr(y, '__pos__')(),
'neg': lambda y: getattr(y, '__neg__')(),
'abs': lambda y: getattr(y, '__abs__')(),
'real': lambda y: getattr(y, 'real'),
'imag': lambda y: getattr(y, 'imag'),
'nanmin': lambda y: getattr(y, 'nanmin')(axis=axis),
'amin': lambda y: getattr(y, 'amin')(axis=axis),
'nanmax': lambda y: getattr(y, 'nanmax')(axis=axis),
'amax': lambda y: getattr(y, 'amax')(axis=axis),
'median': lambda y: getattr(y, 'median')(axis=axis),
'mean': lambda y: getattr(y, 'mean')(axis=axis),
'std': lambda y: getattr(y, 'std')(axis=axis),
'var': lambda y: getattr(y, 'var')(axis=axis),
'argmin_nonflat': lambda y: getattr(y, 'argmin_nonflat')(
axis=axis),
'argmin': lambda y: getattr(y, 'argmin')(axis=axis),
'argmax_nonflat': lambda y: getattr(y, 'argmax_nonflat')(
axis=axis),
'argmax': lambda y: getattr(y, 'argmax')(axis=axis),
'conjugate': lambda y: getattr(y, 'conjugate')(),
'sum': lambda y: getattr(y, 'sum')(axis=axis),
'prod': lambda y: getattr(y, 'prod')(axis=axis),
'unique': lambda y: getattr(y, 'unique')(),
'copy': lambda y: getattr(y, 'copy')(),
'copy_empty': lambda y: getattr(y, 'copy_empty')(),
'isnan': lambda y: getattr(y, 'isnan')(),
'isinf': lambda y: getattr(y, 'isinf')(),
'isfinite': lambda y: getattr(y, 'isfinite')(),
'nan_to_num': lambda y: getattr(y, 'nan_to_num')(),
'all': lambda y: getattr(y, 'all')(axis=axis),
'any': lambda y: getattr(y, 'any')(axis=axis),
'None': lambda y: y}
return translation[op](x, **kwargs)
def binary_operation(self, x, y, op='None', cast=0):
translation = {'add': lambda z: getattr(z, '__add__'),
'radd': lambda z: getattr(z, '__radd__'),
'iadd': lambda z: getattr(z, '__iadd__'),
'sub': lambda z: getattr(z, '__sub__'),
'rsub': lambda z: getattr(z, '__rsub__'),
'isub': lambda z: getattr(z, '__isub__'),
'mul': lambda z: getattr(z, '__mul__'),
'rmul': lambda z: getattr(z, '__rmul__'),
'imul': lambda z: getattr(z, '__imul__'),
'div': lambda z: getattr(z, '__div__'),
'rdiv': lambda z: getattr(z, '__rdiv__'),
'idiv': lambda z: getattr(z, '__idiv__'),
'pow': lambda z: getattr(z, '__pow__'),
'rpow': lambda z: getattr(z, '__rpow__'),
'ipow': lambda z: getattr(z, '__ipow__'),
'ne': lambda z: getattr(z, '__ne__'),
'lt': lambda z: getattr(z, '__lt__'),
'le': lambda z: getattr(z, '__le__'),
'eq': lambda z: getattr(z, '__eq__'),
'ge': lambda z: getattr(z, '__ge__'),
'gt': lambda z: getattr(z, '__gt__'),
'None': lambda z: lambda u: u}
if (cast & 1) != 0:
x = self.cast(x)
if (cast & 2) != 0:
y = self.cast(y)
return translation[op](x)(y)
def __add__(self, other):
return self._binary_helper(other, op='add')
__radd__ = __add__
def __iadd__(self, other):
......@@ -1269,6 +1340,7 @@ class field(object):
def __mul__(self, other):
return self._binary_helper(other, op='mul')
__rmul__ = __mul__
def __imul__(self, other):
......@@ -1282,6 +1354,7 @@ class field(object):
def __idiv__(self, other):
return self._binary_helper(other, op='idiv', inplace=True)
__truediv__ = __div__
__itruediv__ = __idiv__
......@@ -1318,6 +1391,7 @@ class field(object):
def __gt__(self, other):
return self._binary_helper(other, op='gt')
class EmptyField(field):
def __init__(self):
pass
......@@ -123,3 +123,31 @@ class Test_field_init(unittest.TestCase):
f = field(domain=(s,), dtype=s.dtype, datamodel=datamodel)
assert (f.domain[0] is s)
assert (s.check_codomain(f.codomain[0]))
assert (s.get_shape() == f.get_shape())
class Test_field_multiple_init(unittest.TestCase):
@parameterized.expand(
itertools.product([(1,)],
[True],
[0],
[None],
[False],
fft_modules,
DATAMODELS['rg_space']),
testcase_func_name=custom_name_func)
def test_multiple_space_init(self, shape, zerocenter,
complexity, distances, harmonic,
fft_module, datamodel):
s1 = rg_space(shape=shape, zerocenter=zerocenter,
complexity=complexity, distances=distances,
harmonic=harmonic, fft_module=fft_module)
s2 = rg_space(shape=shape, zerocenter=zerocenter,
complexity=complexity, distances=distances,
harmonic=harmonic, fft_module=fft_module)
f = field(domain=(s1, s2), dtype=s1.dtype, datamodel=datamodel)
assert (f.domain[0] is s1)
assert (f.domain[1] is s2)
assert (s1.check_codomain(f.codomain[0]))
assert (s2.check_codomain(f.codomain[1]))
assert (s1.get_shape() + s2.get_shape() == f.get_shape())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment