Commit 7714b6cb authored by csongor's avatar csongor
Browse files

WIP: fix field.__init__

parent 88ed077d
Pipeline #3932 skipped
...@@ -9,6 +9,8 @@ from nifty.config import about,\ ...@@ -9,6 +9,8 @@ from nifty.config import about,\
nifty_configuration as gc,\ nifty_configuration as gc,\
dependency_injector as gdi dependency_injector as gdi
from nifty.nifty_core import space
import nifty.nifty_utilities as utilities import nifty.nifty_utilities as utilities
POINT_DISTRIBUTION_STRATEGIES = DISTRIBUTION_STRATEGIES['global'] POINT_DISTRIBUTION_STRATEGIES = DISTRIBUTION_STRATEGIES['global']
...@@ -202,11 +204,10 @@ class field(object): ...@@ -202,11 +204,10 @@ class field(object):
if val is None: if val is None:
if kwargs == {}: if kwargs == {}:
val = self._map(lambda: self.cast((0,))) val = map(lambda z: self.cast(z), (0,))
else: else:
val = self._map(lambda: self.domain.get_random_values( val = map(lambda z: self.domain.get_random_values(
codomain=self.codomain, codomain=z, **kwargs), self.codomain)
**kwargs))
self.set_val(new_val=val, copy=copy) self.set_val(new_val=val, copy=copy)
def _parse_comm(self, comm): def _parse_comm(self, comm):
...@@ -229,7 +230,7 @@ class field(object): ...@@ -229,7 +230,7 @@ class field(object):
return result_comm return result_comm
def check_valid_domain(self, domain): def check_valid_domain(self, domain):
if not isinstance(domain, np.ndarray): if not isinstance(domain, tuple):
raise TypeError(about._errors.cstring( raise TypeError(about._errors.cstring(
"ERROR: The given domain is not a list.")) "ERROR: The given domain is not a list."))
for d in domain: for d in domain:
...@@ -245,15 +246,15 @@ class field(object): ...@@ -245,15 +246,15 @@ class field(object):
def check_codomain(self, domain, codomain): def check_codomain(self, domain, codomain):
if codomain is None: if codomain is None:
return False return False
if domain.shape == codomain.shape: if len(domain) == len(codomain):
return np.all(map((lambda d, c: d._check_codomain(c)), domain, return np.all(map((lambda d, c: d._check_codomain(c)), domain,
codomain)) codomain))
else: else:
return False return False
def get_codomain(self, domain): def get_codomain(self, domain):
if domain.shape == (1,): if len(domain) == 1:
return np.array(domain[0].get_codomain()) return (domain[0].get_codomain(),)
else: else:
# TODO implement for multiple domain get_codomain need # TODO implement for multiple domain get_codomain need
# calc_transform # calc_transform
...@@ -268,7 +269,7 @@ class field(object): ...@@ -268,7 +269,7 @@ class field(object):
else: else:
working_field = self.copy_empty() working_field = self.copy_empty()
data_object = self._map( data_object = map(
lambda z: self.domain.apply_scalar_function(z, function, inplace), lambda z: self.domain.apply_scalar_function(z, function, inplace),
self.get_val()) self.get_val())
...@@ -276,7 +277,7 @@ class field(object): ...@@ -276,7 +277,7 @@ class field(object):
return working_field return working_field
def copy(self, domain=None, codomain=None): def copy(self, domain=None, codomain=None):
copied_val = self._map( copied_val = map(
lambda z: self.domain.unary_operation(z, op='copy'), lambda z: self.domain.unary_operation(z, op='copy'),
self.get_val()) self.get_val())
new_field = self.copy_empty(domain=domain, codomain=codomain) new_field = self.copy_empty(domain=domain, codomain=codomain)
...@@ -334,10 +335,10 @@ class field(object): ...@@ -334,10 +335,10 @@ class field(object):
""" """
if new_val is not None: if new_val is not None:
if copy: if copy:
new_val = self._map( new_val = map(
lambda z: self.domain.unary_operation(z, 'copy'), lambda z: self.unary_operation(z, 'copy'),
new_val) new_val)
self.val = self._map(lambda z: self.domain.cast(z), new_val) self.val = map(lambda z: self.cast(z), new_val)
return self.val return self.val
def get_val(self): def get_val(self):
...@@ -359,7 +360,7 @@ class field(object): ...@@ -359,7 +360,7 @@ class field(object):
if len(key) > len(self.ishape): if len(key) > len(self.ishape):
if is_data_container: if is_data_container:
gotten = self._map( gotten = map(
lambda z: self.domain.getitem( lambda z: self.domain.getitem(
z, key[len(self.ishape):]), z, key[len(self.ishape):]),
gotten) gotten)
...@@ -382,7 +383,7 @@ class field(object): ...@@ -382,7 +383,7 @@ class field(object):
is_data_container = False is_data_container = False
if is_data_container: if is_data_container:
gotten = self._map( gotten = map(
lambda z1, z2: self.domain.setitem( lambda z1, z2: self.domain.setitem(
z1, z2, key[len(self.ishape):]), z1, z2, key[len(self.ishape):]),
gotten, value) gotten, value)
...@@ -607,7 +608,7 @@ class field(object): ...@@ -607,7 +608,7 @@ class field(object):
else: else:
new_field = self.copy_empty() new_field = self.copy_empty()
new_val = self._map(lambda y: self.domain.calc_weight(y, power=power), new_val = map(lambda y: self.domain.calc_weight(y, power=power),
self.get_val()) self.get_val())
new_field.set_val(new_val=new_val) new_field.set_val(new_val=new_val)
...@@ -675,12 +676,12 @@ class field(object): ...@@ -675,12 +676,12 @@ class field(object):
casted_x = self._cast_to_ishape(x) casted_x = self._cast_to_ishape(x)
# Compute the dot respecting the fact of discrete/continous spaces # Compute the dot respecting the fact of discrete/continous spaces
if self.domain.discrete or bare: if self.domain.discrete or bare:
result = self._map( result = map(
lambda z1, z2: self.domain.calc_dot(z1, z2), lambda z1, z2: self.domain.calc_dot(z1, z2),
self.get_val(), self.get_val(),
casted_x) casted_x)
else: else:
result = self._map( result = map(
lambda z1, z2: self.domain.calc_dot( lambda z1, z2: self.domain.calc_dot(
self.domain.calc_weight(z1, power=1), self.domain.calc_weight(z1, power=1),
z2), z2),
...@@ -744,7 +745,7 @@ class field(object): ...@@ -744,7 +745,7 @@ class field(object):
else: else:
work_field = self.copy_empty() work_field = self.copy_empty()
new_val = self._map( new_val = map(
lambda z: self.domain.unary_operation(z, 'conjugate'), lambda z: self.domain.unary_operation(z, 'conjugate'),
self.get_val()) self.get_val())
work_field.set_val(new_val=new_val) work_field.set_val(new_val=new_val)
...@@ -789,7 +790,7 @@ class field(object): ...@@ -789,7 +790,7 @@ class field(object):
else: else:
assert(new_domain.check_codomain(new_codomain)) assert(new_domain.check_codomain(new_codomain))
new_val = self._map( new_val = map(
lambda z: self.domain.calc_transform( lambda z: self.domain.calc_transform(
z, codomain=new_domain, **kwargs), z, codomain=new_domain, **kwargs),
self.get_val()) self.get_val())
...@@ -835,7 +836,7 @@ class field(object): ...@@ -835,7 +836,7 @@ class field(object):
else: else:
new_field = self.copy_empty() new_field = self.copy_empty()
new_val = self._map( new_val = map(
lambda z: self.domain.calc_smooth(z, sigma=sigma, **kwargs), lambda z: self.domain.calc_smooth(z, sigma=sigma, **kwargs),
self.get_val()) self.get_val())
...@@ -885,7 +886,7 @@ class field(object): ...@@ -885,7 +886,7 @@ class field(object):
kwargs.__delitem__("codomain") kwargs.__delitem__("codomain")
about.warnings.cprint("WARNING: codomain was removed from kwargs.") about.warnings.cprint("WARNING: codomain was removed from kwargs.")
power_spectrum = self._map( power_spectrum = map(
lambda z: self.domain.calc_power(z, codomain=self.codomain, lambda z: self.domain.calc_power(z, codomain=self.codomain,
**kwargs), **kwargs),
self.get_val()) self.get_val())
...@@ -918,7 +919,7 @@ class field(object): ...@@ -918,7 +919,7 @@ class field(object):
The new diagonal operator instance. The new diagonal operator instance.
""" """
any_zero_Q = self._map(lambda z: (z == 0).any(), self.get_val()) any_zero_Q = map(lambda z: (z == 0).any(), self.get_val())
any_zero_Q = np.any(any_zero_Q) any_zero_Q = np.any(any_zero_Q)
if any_zero_Q: if any_zero_Q:
raise AttributeError( raise AttributeError(
...@@ -1020,7 +1021,7 @@ class field(object): ...@@ -1020,7 +1021,7 @@ class field(object):
"\n- ishape = " + str(self.ishape) "\n- ishape = " + str(self.ishape)
def _unary_helper(self, x, op, **kwargs): def _unary_helper(self, x, op, **kwargs):
result = self._map( result = map(
lambda z: self.domain.unary_operation(z, op=op, **kwargs), lambda z: self.domain.unary_operation(z, op=op, **kwargs),
self.get_val()) self.get_val())
return result return result
...@@ -1237,7 +1238,7 @@ class field(object): ...@@ -1237,7 +1238,7 @@ class field(object):
else: else:
other_val = self._cast_to_tensor_helper(other_val) other_val = self._cast_to_tensor_helper(other_val)
new_val = self._map( new_val = map(
lambda z1, z2: self.domain.binary_operation(z1, z2, op=op, cast=0), lambda z1, z2: self.domain.binary_operation(z1, z2, op=op, cast=0),
self.get_val(), self.get_val(),
other_val) other_val)
......
...@@ -427,9 +427,8 @@ class power_indices(object): ...@@ -427,9 +427,8 @@ class power_indices(object):
class rg_power_indices(power_indices): class rg_power_indices(power_indices):
def __init__(self, shape, dgrid, datamodel, def __init__(self, shape, dgrid, allowed_distribution_strategies,
allowed_distribution_strategies, datamodel='not', zerocentered=False, log=False, nbin=None,
zerocentered=False, log=False, nbin=None,
binbounds=None, comm=None): binbounds=None, comm=None):
""" """
Returns an instance of the power_indices class. Given the shape and Returns an instance of the power_indices class. Given the shape and
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from numpy.testing import assert_equal,\ from numpy.testing import assert_equal, \
assert_almost_equal,\ assert_almost_equal, \
assert_raises assert_raises
from nose_parameterized import parameterized from nose_parameterized import parameterized
...@@ -9,20 +9,20 @@ import unittest ...@@ -9,20 +9,20 @@ import unittest
import itertools import itertools
import numpy as np import numpy as np
from nifty import space,\ from nifty import space, \
point_space,\ point_space, \
rg_space,\ rg_space, \
lm_space,\ lm_space, \
hp_space,\ hp_space, \
gl_space gl_space
from nifty.nifty_field import field from nifty.nifty_field import field
from nifty.nifty_core import POINT_DISTRIBUTION_STRATEGIES from nifty.nifty_core import POINT_DISTRIBUTION_STRATEGIES
from nifty.rg.nifty_rg import RG_DISTRIBUTION_STRATEGIES,\ from nifty.rg.nifty_rg import RG_DISTRIBUTION_STRATEGIES, \
gc as RG_GC gc as RG_GC
from nifty.lm.nifty_lm import LM_DISTRIBUTION_STRATEGIES,\ from nifty.lm.nifty_lm import LM_DISTRIBUTION_STRATEGIES, \
GL_DISTRIBUTION_STRATEGIES,\ GL_DISTRIBUTION_STRATEGIES, \
HP_DISTRIBUTION_STRATEGIES HP_DISTRIBUTION_STRATEGIES
...@@ -34,6 +34,7 @@ def custom_name_func(testcase_func, param_num, param): ...@@ -34,6 +34,7 @@ def custom_name_func(testcase_func, param_num, param):
parameterized.to_safe_name("_".join(str(x) for x in param.args)), parameterized.to_safe_name("_".join(str(x) for x in param.args)),
) )
############################################################################### ###############################################################################
############################################################################### ###############################################################################
...@@ -92,62 +93,33 @@ for param in itertools.product([(1,), (4, 6), (5, 8)], ...@@ -92,62 +93,33 @@ for param in itertools.product([(1,), (4, 6), (5, 8)],
[False], [False],
DATAMODELS['rg_space'], DATAMODELS['rg_space'],
fft_modules): fft_modules):
space_list += [[(rg_space(shape=param[0], space_list += [[rg_space(shape=param[0],
zerocenter=param[1], zerocenter=param[1],
complexity=param[2], complexity=param[2],
distances=param[3], distances=param[3],
harmonic=param[4], harmonic=param[4],
fft_module=param[6]),param[6])]] fft_module=param[6]), param[5]]]
############################################################################### ###############################################################################
############################################################################### ###############################################################################
class Test_field_init(unittest.TestCase): class Test_field_init(unittest.TestCase):
@parameterized.expand(
@parameterized.expand(space_list) itertools.product([(1,), (4, 6), (5, 8)],
def test_successfull_init_and_attributes(self, s, datamodel): [False, True],
f = field(domain=np.array([s]), dtype=s.dtype, datamodel=datamodel) [0, 1, 2],
assert(f.domain[0] is s) [None, 0.3],
assert(s.check_codomain(f.codomain[0])) [False],
fft_modules,
DATAMODELS['rg_space']),
testcase_func_name=custom_name_func)
def test_successfull_init_and_attributes(self, shape, zerocenter,
complexity, distances, harmonic,
fft_module, datamodel):
s = rg_space(shape=shape, zerocenter=zerocenter,
complexity=complexity, distances=distances,
harmonic=harmonic, fft_module=fft_module)
f = field(domain=(s,), dtype=s.dtype, datamodel=datamodel)
assert (f.domain[0] is s)
assert (s.check_codomain(f.codomain[0]))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment