Commit 5db4d734 authored by csongor's avatar csongor
Browse files

WIP: fix field.__init__ for multiple spaces

parent 7714b6cb
Pipeline #3933 skipped
...@@ -2,12 +2,12 @@ from __future__ import division ...@@ -2,12 +2,12 @@ from __future__ import division
import numpy as np import numpy as np
import pylab as pl import pylab as pl
from d2o import distributed_data_object,\ from d2o import distributed_data_object, \
STRATEGIES as DISTRIBUTION_STRATEGIES STRATEGIES as DISTRIBUTION_STRATEGIES
from nifty.config import about,\ from nifty.config import about, \
nifty_configuration as gc,\ nifty_configuration as gc, \
dependency_injector as gdi dependency_injector as gdi
from nifty.nifty_core import space from nifty.nifty_core import space
...@@ -104,7 +104,7 @@ class field(object): ...@@ -104,7 +104,7 @@ class field(object):
def __init__(self, domain=None, val=None, codomain=None, comm=gc[ def __init__(self, domain=None, val=None, codomain=None, comm=gc[
'default_comm'], copy=False, dtype=np.dtype('float64'), datamodel='not', 'default_comm'], copy=False, dtype=np.dtype('float64'), datamodel='not',
**kwargs): **kwargs):
""" """
Sets the attributes for a field class instance. Sets the attributes for a field class instance.
...@@ -256,12 +256,12 @@ class field(object): ...@@ -256,12 +256,12 @@ class field(object):
if len(domain) == 1: if len(domain) == 1:
return (domain[0].get_codomain(),) return (domain[0].get_codomain(),)
else: else:
# TODO implement for multiple domain get_codomain need codomain = tuple(space.get_codomain() for space in domain)
# calc_transform self.codomain = codomain
return np.empty((0,)) return codomain
def __len__(self): def __len__(self):
return int(self.get_dim(split=True)[0]) return int(self.get_dim()[0])
def apply_scalar_function(self, function, inplace=False): def apply_scalar_function(self, function, inplace=False):
if inplace: if inplace:
...@@ -312,11 +312,11 @@ class field(object): ...@@ -312,11 +312,11 @@ class field(object):
datamodel = self.datamodel datamodel = self.datamodel
if (domain is self.domain and if (domain is self.domain and
codomain is self.codomain and codomain is self.codomain and
dtype == self.dtype and dtype == self.dtype and
comm == self.comm and comm == self.comm and
datamodel == self.datamodel and datamodel == self.datamodel and
kwargs == {}): kwargs == {}):
new_field = self._fast_copy_empty() new_field = self._fast_copy_empty()
else: else:
new_field = field(domain=domain, codomain=codomain, dtype=dtype, new_field = field(domain=domain, codomain=codomain, dtype=dtype,
...@@ -336,8 +336,8 @@ class field(object): ...@@ -336,8 +336,8 @@ class field(object):
if new_val is not None: if new_val is not None:
if copy: if copy:
new_val = map( new_val = map(
lambda z: self.unary_operation(z, 'copy'), lambda z: self.unary_operation(z, 'copy'),
new_val) new_val)
self.val = map(lambda z: self.cast(z), new_val) self.val = map(lambda z: self.cast(z), new_val)
return self.val return self.val
...@@ -348,7 +348,7 @@ class field(object): ...@@ -348,7 +348,7 @@ class field(object):
def __getitem__(self, key): def __getitem__(self, key):
if np.isscalar(key) == True or isinstance(key, slice): if np.isscalar(key) == True or isinstance(key, slice):
key = (key, ) key = (key,)
if self.ishape == (): if self.ishape == ():
return self.domain.getitem(self.get_val(), key) return self.domain.getitem(self.get_val(), key)
else: else:
...@@ -362,7 +362,7 @@ class field(object): ...@@ -362,7 +362,7 @@ class field(object):
if is_data_container: if is_data_container:
gotten = map( gotten = map(
lambda z: self.domain.getitem( lambda z: self.domain.getitem(
z, key[len(self.ishape):]), z, key[len(self.ishape):]),
gotten) gotten)
else: else:
gotten = self.domain.getitem(gotten, gotten = self.domain.getitem(gotten,
...@@ -371,7 +371,7 @@ class field(object): ...@@ -371,7 +371,7 @@ class field(object):
def __setitem__(self, key, value): def __setitem__(self, key, value):
if np.isscalar(key) or isinstance(key, slice): if np.isscalar(key) or isinstance(key, slice):
key = (key, ) key = (key,)
if self.ishape == (): if self.ishape == ():
return self.domain.setitem(self.get_val(), value, key) return self.domain.setitem(self.get_val(), value, key)
else: else:
...@@ -385,7 +385,7 @@ class field(object): ...@@ -385,7 +385,7 @@ class field(object):
if is_data_container: if is_data_container:
gotten = map( gotten = map(
lambda z1, z2: self.domain.setitem( lambda z1, z2: self.domain.setitem(
z1, z2, key[len(self.ishape):]), z1, z2, key[len(self.ishape):]),
gotten, value) gotten, value)
else: else:
gotten = self.domain.setitem(gotten, value, gotten = self.domain.setitem(gotten, value,
...@@ -393,13 +393,13 @@ class field(object): ...@@ -393,13 +393,13 @@ class field(object):
else: else:
dummy = np.empty(self.ishape) dummy = np.empty(self.ishape)
gotten = self.val.__setitem__(key, self.cast( gotten = self.val.__setitem__(key, self.cast(
value, ishape=np.shape(dummy[key]))) value, ishape=np.shape(dummy[key])))
return gotten return gotten
def get_shape(self): def get_shape(self):
if len(self.domain) > 1: if len(self.domain) > 1:
global_shape = reduce(lambda x, y: x.get_shape()+y.get_shape(), global_shape = reduce(lambda x, y: x.get_shape() + y.get_shape(),
self.domain) self.domain)
else: else:
global_shape = self.domain[0].get_shape() global_shape = self.domain[0].get_shape()
...@@ -408,7 +408,7 @@ class field(object): ...@@ -408,7 +408,7 @@ class field(object):
else: else:
return () return ()
def get_dim(self, split=False): def get_dim(self):
""" """
Computes the (array) dimension of the underlying space. Computes the (array) dimension of the underlying space.
...@@ -430,7 +430,7 @@ class field(object): ...@@ -430,7 +430,7 @@ class field(object):
def get_dof(self, split=False): def get_dof(self, split=False):
dim = self.get_dim() dim = self.get_dim()
if np.issubdtype(self.dtype, np.complex): if np.issubdtype(self.dtype, np.complex):
return 2*dim return 2 * dim
else: else:
return dim return dim
...@@ -541,7 +541,7 @@ class field(object): ...@@ -541,7 +541,7 @@ class field(object):
return self.cast(x, dtype=dtype) return self.cast(x, dtype=dtype)
def _complement_cast(self, x): def _complement_cast(self, x):
#TODO implement complement cast for multiple spaces. # TODO implement complement cast for multiple spaces.
return x return x
def set_domain(self, new_domain=None, force=False): def set_domain(self, new_domain=None, force=False):
...@@ -559,7 +559,7 @@ class field(object): ...@@ -559,7 +559,7 @@ class field(object):
if new_domain is None: if new_domain is None:
new_domain = self.codomain.get_codomain() new_domain = self.codomain.get_codomain()
elif not force: elif not force:
assert(self.codomain.check_codomain(new_domain)) assert (self.codomain.check_codomain(new_domain))
self.domain = new_domain self.domain = new_domain
return self.domain return self.domain
...@@ -578,7 +578,7 @@ class field(object): ...@@ -578,7 +578,7 @@ class field(object):
if new_codomain is None: if new_codomain is None:
new_codomain = self.domain.get_codomain() new_codomain = self.domain.get_codomain()
elif not force: elif not force:
assert(self.domain.check_codomain(new_codomain)) assert (self.domain.check_codomain(new_codomain))
self.codomain = new_codomain self.codomain = new_codomain
return self.codomain return self.codomain
...@@ -609,7 +609,7 @@ class field(object): ...@@ -609,7 +609,7 @@ class field(object):
new_field = self.copy_empty() new_field = self.copy_empty()
new_val = map(lambda y: self.domain.calc_weight(y, power=power), new_val = map(lambda y: self.domain.calc_weight(y, power=power),
self.get_val()) self.get_val())
new_field.set_val(new_val=new_val) new_field.set_val(new_val=new_val)
return new_field return new_field
...@@ -630,9 +630,9 @@ class field(object): ...@@ -630,9 +630,9 @@ class field(object):
""" """
if q == 0.5: if q == 0.5:
return (self.dot(x=self))**(1 / 2) return (self.dot(x=self)) ** (1 / 2)
else: else:
return self.dot(x=self**(q - 1))**(1 / q) return self.dot(x=self ** (q - 1)) ** (1 / q)
def dot(self, x=None, axis=None, bare=False): def dot(self, x=None, axis=None, bare=False):
""" """
...@@ -788,7 +788,7 @@ class field(object): ...@@ -788,7 +788,7 @@ class field(object):
else: else:
new_codomain = new_domain.get_codomain() new_codomain = new_domain.get_codomain()
else: else:
assert(new_domain.check_codomain(new_codomain)) assert (new_domain.check_codomain(new_codomain))
new_val = map( new_val = map(
lambda z: self.domain.calc_transform( lambda z: self.domain.calc_transform(
...@@ -882,7 +882,7 @@ class field(object): ...@@ -882,7 +882,7 @@ class field(object):
Returns the power spectrum. Returns the power spectrum.
""" """
if("codomain" in kwargs): if ("codomain" in kwargs):
kwargs.__delitem__("codomain") kwargs.__delitem__("codomain")
about.warnings.cprint("WARNING: codomain was removed from kwargs.") about.warnings.cprint("WARNING: codomain was removed from kwargs.")
...@@ -1013,12 +1013,12 @@ class field(object): ...@@ -1013,12 +1013,12 @@ class field(object):
minmax = [self.min(), self.max()] minmax = [self.min(), self.max()]
mean = self.mean() mean = self.mean()
return "nifty_core.field instance\n- domain = " + \ return "nifty_core.field instance\n- domain = " + \
repr(self.domain) +\ repr(self.domain) + \
"\n- val = " + repr(self.get_val()) + \ "\n- val = " + repr(self.get_val()) + \
"\n - min.,max. = " + str(minmax) + \ "\n - min.,max. = " + str(minmax) + \
"\n - mean = " + str(mean) + \ "\n - mean = " + str(mean) + \
"\n- codomain = " + repr(self.codomain) + \ "\n- codomain = " + repr(self.codomain) + \
"\n- ishape = " + str(self.ishape) "\n- ishape = " + str(self.ishape)
def _unary_helper(self, x, op, **kwargs): def _unary_helper(self, x, op, **kwargs):
result = map( result = map(
...@@ -1233,13 +1233,10 @@ class field(object): ...@@ -1233,13 +1233,10 @@ class field(object):
other_val = other other_val = other
# bring other_val into the right shape # bring other_val into the right shape
if self.ishape == (): other_val = self._cast_to_d2o(other_val)
other_val = self._cast_to_scalar_helper(other_val)
else:
other_val = self._cast_to_tensor_helper(other_val)
new_val = map( new_val = map(
lambda z1, z2: self.domain.binary_operation(z1, z2, op=op, cast=0), lambda z1, z2: self.binary_operation(z1, z2, op=op, cast=0),
self.get_val(), self.get_val(),
other_val) other_val)
...@@ -1251,8 +1248,82 @@ class field(object): ...@@ -1251,8 +1248,82 @@ class field(object):
working_field.set_val(new_val=new_val) working_field.set_val(new_val=new_val)
return working_field return working_field
def unary_operation(self, x, op='None', axis=None, **kwargs):
"""
x must be a numpy array which is compatible with the space!
Valid operations are
"""
translation = {'pos': lambda y: getattr(y, '__pos__')(),
'neg': lambda y: getattr(y, '__neg__')(),
'abs': lambda y: getattr(y, '__abs__')(),
'real': lambda y: getattr(y, 'real'),
'imag': lambda y: getattr(y, 'imag'),
'nanmin': lambda y: getattr(y, 'nanmin')(axis=axis),
'amin': lambda y: getattr(y, 'amin')(axis=axis),
'nanmax': lambda y: getattr(y, 'nanmax')(axis=axis),
'amax': lambda y: getattr(y, 'amax')(axis=axis),
'median': lambda y: getattr(y, 'median')(axis=axis),
'mean': lambda y: getattr(y, 'mean')(axis=axis),
'std': lambda y: getattr(y, 'std')(axis=axis),
'var': lambda y: getattr(y, 'var')(axis=axis),
'argmin_nonflat': lambda y: getattr(y, 'argmin_nonflat')(
axis=axis),
'argmin': lambda y: getattr(y, 'argmin')(axis=axis),
'argmax_nonflat': lambda y: getattr(y, 'argmax_nonflat')(
axis=axis),
'argmax': lambda y: getattr(y, 'argmax')(axis=axis),
'conjugate': lambda y: getattr(y, 'conjugate')(),
'sum': lambda y: getattr(y, 'sum')(axis=axis),
'prod': lambda y: getattr(y, 'prod')(axis=axis),
'unique': lambda y: getattr(y, 'unique')(),
'copy': lambda y: getattr(y, 'copy')(),
'copy_empty': lambda y: getattr(y, 'copy_empty')(),
'isnan': lambda y: getattr(y, 'isnan')(),
'isinf': lambda y: getattr(y, 'isinf')(),
'isfinite': lambda y: getattr(y, 'isfinite')(),
'nan_to_num': lambda y: getattr(y, 'nan_to_num')(),
'all': lambda y: getattr(y, 'all')(axis=axis),
'any': lambda y: getattr(y, 'any')(axis=axis),
'None': lambda y: y}
return translation[op](x, **kwargs)
def binary_operation(self, x, y, op='None', cast=0):
translation = {'add': lambda z: getattr(z, '__add__'),
'radd': lambda z: getattr(z, '__radd__'),
'iadd': lambda z: getattr(z, '__iadd__'),
'sub': lambda z: getattr(z, '__sub__'),
'rsub': lambda z: getattr(z, '__rsub__'),
'isub': lambda z: getattr(z, '__isub__'),
'mul': lambda z: getattr(z, '__mul__'),
'rmul': lambda z: getattr(z, '__rmul__'),
'imul': lambda z: getattr(z, '__imul__'),
'div': lambda z: getattr(z, '__div__'),
'rdiv': lambda z: getattr(z, '__rdiv__'),
'idiv': lambda z: getattr(z, '__idiv__'),
'pow': lambda z: getattr(z, '__pow__'),
'rpow': lambda z: getattr(z, '__rpow__'),
'ipow': lambda z: getattr(z, '__ipow__'),
'ne': lambda z: getattr(z, '__ne__'),
'lt': lambda z: getattr(z, '__lt__'),
'le': lambda z: getattr(z, '__le__'),
'eq': lambda z: getattr(z, '__eq__'),
'ge': lambda z: getattr(z, '__ge__'),
'gt': lambda z: getattr(z, '__gt__'),
'None': lambda z: lambda u: u}
if (cast & 1) != 0:
x = self.cast(x)
if (cast & 2) != 0:
y = self.cast(y)
return translation[op](x)(y)
def __add__(self, other): def __add__(self, other):
return self._binary_helper(other, op='add') return self._binary_helper(other, op='add')
__radd__ = __add__ __radd__ = __add__
def __iadd__(self, other): def __iadd__(self, other):
...@@ -1269,6 +1340,7 @@ class field(object): ...@@ -1269,6 +1340,7 @@ class field(object):
def __mul__(self, other): def __mul__(self, other):
return self._binary_helper(other, op='mul') return self._binary_helper(other, op='mul')
__rmul__ = __mul__ __rmul__ = __mul__
def __imul__(self, other): def __imul__(self, other):
...@@ -1282,6 +1354,7 @@ class field(object): ...@@ -1282,6 +1354,7 @@ class field(object):
def __idiv__(self, other): def __idiv__(self, other):
return self._binary_helper(other, op='idiv', inplace=True) return self._binary_helper(other, op='idiv', inplace=True)
__truediv__ = __div__ __truediv__ = __div__
__itruediv__ = __idiv__ __itruediv__ = __idiv__
...@@ -1318,6 +1391,7 @@ class field(object): ...@@ -1318,6 +1391,7 @@ class field(object):
def __gt__(self, other): def __gt__(self, other):
return self._binary_helper(other, op='gt') return self._binary_helper(other, op='gt')
class EmptyField(field): class EmptyField(field):
def __init__(self): def __init__(self):
pass pass
...@@ -123,3 +123,31 @@ class Test_field_init(unittest.TestCase): ...@@ -123,3 +123,31 @@ class Test_field_init(unittest.TestCase):
f = field(domain=(s,), dtype=s.dtype, datamodel=datamodel) f = field(domain=(s,), dtype=s.dtype, datamodel=datamodel)
assert (f.domain[0] is s) assert (f.domain[0] is s)
assert (s.check_codomain(f.codomain[0])) assert (s.check_codomain(f.codomain[0]))
assert (s.get_shape() == f.get_shape())
class Test_field_multiple_init(unittest.TestCase):
@parameterized.expand(
itertools.product([(1,)],
[True],
[0],
[None],
[False],
fft_modules,
DATAMODELS['rg_space']),
testcase_func_name=custom_name_func)
def test_multiple_space_init(self, shape, zerocenter,
complexity, distances, harmonic,
fft_module, datamodel):
s1 = rg_space(shape=shape, zerocenter=zerocenter,
complexity=complexity, distances=distances,
harmonic=harmonic, fft_module=fft_module)
s2 = rg_space(shape=shape, zerocenter=zerocenter,
complexity=complexity, distances=distances,
harmonic=harmonic, fft_module=fft_module)
f = field(domain=(s1, s2), dtype=s1.dtype, datamodel=datamodel)
assert (f.domain[0] is s1)
assert (f.domain[1] is s2)
assert (s1.check_codomain(f.codomain[0]))
assert (s2.check_codomain(f.codomain[1]))
assert (s1.get_shape() + s2.get_shape() == f.get_shape())
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment