Commit 7714b6cb authored by csongor's avatar csongor
Browse files

WIP: fix field.__init__

parent 88ed077d
Pipeline #3932 skipped
......@@ -9,6 +9,8 @@ from nifty.config import about,\
nifty_configuration as gc,\
dependency_injector as gdi
from nifty.nifty_core import space
import nifty.nifty_utilities as utilities
POINT_DISTRIBUTION_STRATEGIES = DISTRIBUTION_STRATEGIES['global']
......@@ -202,11 +204,10 @@ class field(object):
if val is None:
if kwargs == {}:
val = self._map(lambda: self.cast((0,)))
val = map(lambda z: self.cast(z), (0,))
else:
val = self._map(lambda: self.domain.get_random_values(
codomain=self.codomain,
**kwargs))
val = map(lambda z: self.domain.get_random_values(
codomain=z, **kwargs), self.codomain)
self.set_val(new_val=val, copy=copy)
def _parse_comm(self, comm):
......@@ -229,7 +230,7 @@ class field(object):
return result_comm
def check_valid_domain(self, domain):
if not isinstance(domain, np.ndarray):
if not isinstance(domain, tuple):
raise TypeError(about._errors.cstring(
"ERROR: The given domain is not a list."))
for d in domain:
......@@ -245,15 +246,15 @@ class field(object):
def check_codomain(self, domain, codomain):
if codomain is None:
return False
if domain.shape == codomain.shape:
if len(domain) == len(codomain):
return np.all(map((lambda d, c: d._check_codomain(c)), domain,
codomain))
else:
return False
def get_codomain(self, domain):
if domain.shape == (1,):
return np.array(domain[0].get_codomain())
if len(domain) == 1:
return (domain[0].get_codomain(),)
else:
# TODO implement for multiple domain get_codomain need
# calc_transform
......@@ -268,7 +269,7 @@ class field(object):
else:
working_field = self.copy_empty()
data_object = self._map(
data_object = map(
lambda z: self.domain.apply_scalar_function(z, function, inplace),
self.get_val())
......@@ -276,7 +277,7 @@ class field(object):
return working_field
def copy(self, domain=None, codomain=None):
copied_val = self._map(
copied_val = map(
lambda z: self.domain.unary_operation(z, op='copy'),
self.get_val())
new_field = self.copy_empty(domain=domain, codomain=codomain)
......@@ -334,10 +335,10 @@ class field(object):
"""
if new_val is not None:
if copy:
new_val = self._map(
lambda z: self.domain.unary_operation(z, 'copy'),
new_val = map(
lambda z: self.unary_operation(z, 'copy'),
new_val)
self.val = self._map(lambda z: self.domain.cast(z), new_val)
self.val = map(lambda z: self.cast(z), new_val)
return self.val
def get_val(self):
......@@ -359,7 +360,7 @@ class field(object):
if len(key) > len(self.ishape):
if is_data_container:
gotten = self._map(
gotten = map(
lambda z: self.domain.getitem(
z, key[len(self.ishape):]),
gotten)
......@@ -382,7 +383,7 @@ class field(object):
is_data_container = False
if is_data_container:
gotten = self._map(
gotten = map(
lambda z1, z2: self.domain.setitem(
z1, z2, key[len(self.ishape):]),
gotten, value)
......@@ -607,7 +608,7 @@ class field(object):
else:
new_field = self.copy_empty()
new_val = self._map(lambda y: self.domain.calc_weight(y, power=power),
new_val = map(lambda y: self.domain.calc_weight(y, power=power),
self.get_val())
new_field.set_val(new_val=new_val)
......@@ -675,12 +676,12 @@ class field(object):
casted_x = self._cast_to_ishape(x)
# Compute the dot respecting the fact of discrete/continous spaces
if self.domain.discrete or bare:
result = self._map(
result = map(
lambda z1, z2: self.domain.calc_dot(z1, z2),
self.get_val(),
casted_x)
else:
result = self._map(
result = map(
lambda z1, z2: self.domain.calc_dot(
self.domain.calc_weight(z1, power=1),
z2),
......@@ -744,7 +745,7 @@ class field(object):
else:
work_field = self.copy_empty()
new_val = self._map(
new_val = map(
lambda z: self.domain.unary_operation(z, 'conjugate'),
self.get_val())
work_field.set_val(new_val=new_val)
......@@ -789,7 +790,7 @@ class field(object):
else:
assert(new_domain.check_codomain(new_codomain))
new_val = self._map(
new_val = map(
lambda z: self.domain.calc_transform(
z, codomain=new_domain, **kwargs),
self.get_val())
......@@ -835,7 +836,7 @@ class field(object):
else:
new_field = self.copy_empty()
new_val = self._map(
new_val = map(
lambda z: self.domain.calc_smooth(z, sigma=sigma, **kwargs),
self.get_val())
......@@ -885,7 +886,7 @@ class field(object):
kwargs.__delitem__("codomain")
about.warnings.cprint("WARNING: codomain was removed from kwargs.")
power_spectrum = self._map(
power_spectrum = map(
lambda z: self.domain.calc_power(z, codomain=self.codomain,
**kwargs),
self.get_val())
......@@ -918,7 +919,7 @@ class field(object):
The new diagonal operator instance.
"""
any_zero_Q = self._map(lambda z: (z == 0).any(), self.get_val())
any_zero_Q = map(lambda z: (z == 0).any(), self.get_val())
any_zero_Q = np.any(any_zero_Q)
if any_zero_Q:
raise AttributeError(
......@@ -1020,7 +1021,7 @@ class field(object):
"\n- ishape = " + str(self.ishape)
def _unary_helper(self, x, op, **kwargs):
result = self._map(
result = map(
lambda z: self.domain.unary_operation(z, op=op, **kwargs),
self.get_val())
return result
......@@ -1237,7 +1238,7 @@ class field(object):
else:
other_val = self._cast_to_tensor_helper(other_val)
new_val = self._map(
new_val = map(
lambda z1, z2: self.domain.binary_operation(z1, z2, op=op, cast=0),
self.get_val(),
other_val)
......
......@@ -427,9 +427,8 @@ class power_indices(object):
class rg_power_indices(power_indices):
def __init__(self, shape, dgrid, datamodel,
allowed_distribution_strategies,
zerocentered=False, log=False, nbin=None,
def __init__(self, shape, dgrid, allowed_distribution_strategies,
datamodel='not', zerocentered=False, log=False, nbin=None,
binbounds=None, comm=None):
"""
Returns an instance of the power_indices class. Given the shape and
......
# -*- coding: utf-8 -*-
from numpy.testing import assert_equal,\
assert_almost_equal,\
from numpy.testing import assert_equal, \
assert_almost_equal, \
assert_raises
from nose_parameterized import parameterized
......@@ -9,21 +9,21 @@ import unittest
import itertools
import numpy as np
from nifty import space,\
point_space,\
rg_space,\
lm_space,\
hp_space,\
from nifty import space, \
point_space, \
rg_space, \
lm_space, \
hp_space, \
gl_space
from nifty.nifty_field import field
from nifty.nifty_core import POINT_DISTRIBUTION_STRATEGIES
from nifty.rg.nifty_rg import RG_DISTRIBUTION_STRATEGIES,\
gc as RG_GC
from nifty.lm.nifty_lm import LM_DISTRIBUTION_STRATEGIES,\
GL_DISTRIBUTION_STRATEGIES,\
HP_DISTRIBUTION_STRATEGIES
from nifty.rg.nifty_rg import RG_DISTRIBUTION_STRATEGIES, \
gc as RG_GC
from nifty.lm.nifty_lm import LM_DISTRIBUTION_STRATEGIES, \
GL_DISTRIBUTION_STRATEGIES, \
HP_DISTRIBUTION_STRATEGIES
###############################################################################
......@@ -34,6 +34,7 @@ def custom_name_func(testcase_func, param_num, param):
parameterized.to_safe_name("_".join(str(x) for x in param.args)),
)
###############################################################################
###############################################################################
......@@ -92,62 +93,33 @@ for param in itertools.product([(1,), (4, 6), (5, 8)],
[False],
DATAMODELS['rg_space'],
fft_modules):
space_list += [[(rg_space(shape=param[0],
zerocenter=param[1],
complexity=param[2],
distances=param[3],
harmonic=param[4],
fft_module=param[6]),param[6])]]
space_list += [[rg_space(shape=param[0],
zerocenter=param[1],
complexity=param[2],
distances=param[3],
harmonic=param[4],
fft_module=param[6]), param[5]]]
###############################################################################
###############################################################################
class Test_field_init(unittest.TestCase):
@parameterized.expand(space_list)
def test_successfull_init_and_attributes(self, s, datamodel):
f = field(domain=np.array([s]), dtype=s.dtype, datamodel=datamodel)
assert(f.domain[0] is s)
assert(s.check_codomain(f.codomain[0]))
@parameterized.expand(
itertools.product([(1,), (4, 6), (5, 8)],
[False, True],
[0, 1, 2],
[None, 0.3],
[False],
fft_modules,
DATAMODELS['rg_space']),
testcase_func_name=custom_name_func)
def test_successfull_init_and_attributes(self, shape, zerocenter,
complexity, distances, harmonic,
fft_module, datamodel):
s = rg_space(shape=shape, zerocenter=zerocenter,
complexity=complexity, distances=distances,
harmonic=harmonic, fft_module=fft_module)
f = field(domain=(s,), dtype=s.dtype, datamodel=datamodel)
assert (f.domain[0] is s)
assert (s.check_codomain(f.codomain[0]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment