Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
7714b6cb
Commit
7714b6cb
authored
May 29, 2016
by
csongor
Browse files
WIP: fix field.__init__
parent
88ed077d
Pipeline
#3932
skipped
Changes
3
Pipelines
1
Show whitespace changes
Inline
Side-by-side
nifty_field.py
View file @
7714b6cb
...
...
@@ -9,6 +9,8 @@ from nifty.config import about,\
nifty_configuration
as
gc
,
\
dependency_injector
as
gdi
from
nifty.nifty_core
import
space
import
nifty.nifty_utilities
as
utilities
POINT_DISTRIBUTION_STRATEGIES
=
DISTRIBUTION_STRATEGIES
[
'global'
]
...
...
@@ -202,11 +204,10 @@ class field(object):
if
val
is
None
:
if
kwargs
==
{}:
val
=
self
.
_
map
(
lambda
:
self
.
cast
((
0
,))
)
val
=
map
(
lambda
z
:
self
.
cast
(
z
),
(
0
,))
else
:
val
=
self
.
_map
(
lambda
:
self
.
domain
.
get_random_values
(
codomain
=
self
.
codomain
,
**
kwargs
))
val
=
map
(
lambda
z
:
self
.
domain
.
get_random_values
(
codomain
=
z
,
**
kwargs
),
self
.
codomain
)
self
.
set_val
(
new_val
=
val
,
copy
=
copy
)
def
_parse_comm
(
self
,
comm
):
...
...
@@ -229,7 +230,7 @@ class field(object):
return
result_comm
def
check_valid_domain
(
self
,
domain
):
if
not
isinstance
(
domain
,
np
.
ndarray
):
if
not
isinstance
(
domain
,
tuple
):
raise
TypeError
(
about
.
_errors
.
cstring
(
"ERROR: The given domain is not a list."
))
for
d
in
domain
:
...
...
@@ -245,15 +246,15 @@ class field(object):
def
check_codomain
(
self
,
domain
,
codomain
):
if
codomain
is
None
:
return
False
if
domain
.
shape
==
codomain
.
shape
:
if
len
(
domain
)
==
len
(
codomain
)
:
return
np
.
all
(
map
((
lambda
d
,
c
:
d
.
_check_codomain
(
c
)),
domain
,
codomain
))
else
:
return
False
def
get_codomain
(
self
,
domain
):
if
domain
.
shape
==
(
1
,)
:
return
np
.
array
(
domain
[
0
].
get_codomain
())
if
len
(
domain
)
==
1
:
return
(
domain
[
0
].
get_codomain
()
,
)
else
:
# TODO implement for multiple domain get_codomain need
# calc_transform
...
...
@@ -268,7 +269,7 @@ class field(object):
else
:
working_field
=
self
.
copy_empty
()
data_object
=
self
.
_
map
(
data_object
=
map
(
lambda
z
:
self
.
domain
.
apply_scalar_function
(
z
,
function
,
inplace
),
self
.
get_val
())
...
...
@@ -276,7 +277,7 @@ class field(object):
return
working_field
def
copy
(
self
,
domain
=
None
,
codomain
=
None
):
copied_val
=
self
.
_
map
(
copied_val
=
map
(
lambda
z
:
self
.
domain
.
unary_operation
(
z
,
op
=
'copy'
),
self
.
get_val
())
new_field
=
self
.
copy_empty
(
domain
=
domain
,
codomain
=
codomain
)
...
...
@@ -334,10 +335,10 @@ class field(object):
"""
if
new_val
is
not
None
:
if
copy
:
new_val
=
self
.
_
map
(
lambda
z
:
self
.
domain
.
unary_operation
(
z
,
'copy'
),
new_val
=
map
(
lambda
z
:
self
.
unary_operation
(
z
,
'copy'
),
new_val
)
self
.
val
=
self
.
_
map
(
lambda
z
:
self
.
domain
.
cast
(
z
),
new_val
)
self
.
val
=
map
(
lambda
z
:
self
.
cast
(
z
),
new_val
)
return
self
.
val
def
get_val
(
self
):
...
...
@@ -359,7 +360,7 @@ class field(object):
if
len
(
key
)
>
len
(
self
.
ishape
):
if
is_data_container
:
gotten
=
self
.
_
map
(
gotten
=
map
(
lambda
z
:
self
.
domain
.
getitem
(
z
,
key
[
len
(
self
.
ishape
):]),
gotten
)
...
...
@@ -382,7 +383,7 @@ class field(object):
is_data_container
=
False
if
is_data_container
:
gotten
=
self
.
_
map
(
gotten
=
map
(
lambda
z1
,
z2
:
self
.
domain
.
setitem
(
z1
,
z2
,
key
[
len
(
self
.
ishape
):]),
gotten
,
value
)
...
...
@@ -607,7 +608,7 @@ class field(object):
else
:
new_field
=
self
.
copy_empty
()
new_val
=
self
.
_
map
(
lambda
y
:
self
.
domain
.
calc_weight
(
y
,
power
=
power
),
new_val
=
map
(
lambda
y
:
self
.
domain
.
calc_weight
(
y
,
power
=
power
),
self
.
get_val
())
new_field
.
set_val
(
new_val
=
new_val
)
...
...
@@ -675,12 +676,12 @@ class field(object):
casted_x
=
self
.
_cast_to_ishape
(
x
)
# Compute the dot respecting the fact of discrete/continous spaces
if
self
.
domain
.
discrete
or
bare
:
result
=
self
.
_
map
(
result
=
map
(
lambda
z1
,
z2
:
self
.
domain
.
calc_dot
(
z1
,
z2
),
self
.
get_val
(),
casted_x
)
else
:
result
=
self
.
_
map
(
result
=
map
(
lambda
z1
,
z2
:
self
.
domain
.
calc_dot
(
self
.
domain
.
calc_weight
(
z1
,
power
=
1
),
z2
),
...
...
@@ -744,7 +745,7 @@ class field(object):
else
:
work_field
=
self
.
copy_empty
()
new_val
=
self
.
_
map
(
new_val
=
map
(
lambda
z
:
self
.
domain
.
unary_operation
(
z
,
'conjugate'
),
self
.
get_val
())
work_field
.
set_val
(
new_val
=
new_val
)
...
...
@@ -789,7 +790,7 @@ class field(object):
else
:
assert
(
new_domain
.
check_codomain
(
new_codomain
))
new_val
=
self
.
_
map
(
new_val
=
map
(
lambda
z
:
self
.
domain
.
calc_transform
(
z
,
codomain
=
new_domain
,
**
kwargs
),
self
.
get_val
())
...
...
@@ -835,7 +836,7 @@ class field(object):
else
:
new_field
=
self
.
copy_empty
()
new_val
=
self
.
_
map
(
new_val
=
map
(
lambda
z
:
self
.
domain
.
calc_smooth
(
z
,
sigma
=
sigma
,
**
kwargs
),
self
.
get_val
())
...
...
@@ -885,7 +886,7 @@ class field(object):
kwargs
.
__delitem__
(
"codomain"
)
about
.
warnings
.
cprint
(
"WARNING: codomain was removed from kwargs."
)
power_spectrum
=
self
.
_
map
(
power_spectrum
=
map
(
lambda
z
:
self
.
domain
.
calc_power
(
z
,
codomain
=
self
.
codomain
,
**
kwargs
),
self
.
get_val
())
...
...
@@ -918,7 +919,7 @@ class field(object):
The new diagonal operator instance.
"""
any_zero_Q
=
self
.
_
map
(
lambda
z
:
(
z
==
0
).
any
(),
self
.
get_val
())
any_zero_Q
=
map
(
lambda
z
:
(
z
==
0
).
any
(),
self
.
get_val
())
any_zero_Q
=
np
.
any
(
any_zero_Q
)
if
any_zero_Q
:
raise
AttributeError
(
...
...
@@ -1020,7 +1021,7 @@ class field(object):
"
\n
- ishape = "
+
str
(
self
.
ishape
)
def
_unary_helper
(
self
,
x
,
op
,
**
kwargs
):
result
=
self
.
_
map
(
result
=
map
(
lambda
z
:
self
.
domain
.
unary_operation
(
z
,
op
=
op
,
**
kwargs
),
self
.
get_val
())
return
result
...
...
@@ -1237,7 +1238,7 @@ class field(object):
else
:
other_val
=
self
.
_cast_to_tensor_helper
(
other_val
)
new_val
=
self
.
_
map
(
new_val
=
map
(
lambda
z1
,
z2
:
self
.
domain
.
binary_operation
(
z1
,
z2
,
op
=
op
,
cast
=
0
),
self
.
get_val
(),
other_val
)
...
...
nifty_power_indices.py
View file @
7714b6cb
...
...
@@ -427,9 +427,8 @@ class power_indices(object):
class
rg_power_indices
(
power_indices
):
def
__init__
(
self
,
shape
,
dgrid
,
datamodel
,
allowed_distribution_strategies
,
zerocentered
=
False
,
log
=
False
,
nbin
=
None
,
def
__init__
(
self
,
shape
,
dgrid
,
allowed_distribution_strategies
,
datamodel
=
'not'
,
zerocentered
=
False
,
log
=
False
,
nbin
=
None
,
binbounds
=
None
,
comm
=
None
):
"""
Returns an instance of the power_indices class. Given the shape and
...
...
test/test_nifty_field.py
View file @
7714b6cb
# -*- coding: utf-8 -*-
from
numpy.testing
import
assert_equal
,
\
assert_almost_equal
,
\
from
numpy.testing
import
assert_equal
,
\
assert_almost_equal
,
\
assert_raises
from
nose_parameterized
import
parameterized
...
...
@@ -9,20 +9,20 @@ import unittest
import
itertools
import
numpy
as
np
from
nifty
import
space
,
\
point_space
,
\
rg_space
,
\
lm_space
,
\
hp_space
,
\
from
nifty
import
space
,
\
point_space
,
\
rg_space
,
\
lm_space
,
\
hp_space
,
\
gl_space
from
nifty.nifty_field
import
field
from
nifty.nifty_core
import
POINT_DISTRIBUTION_STRATEGIES
from
nifty.rg.nifty_rg
import
RG_DISTRIBUTION_STRATEGIES
,
\
from
nifty.rg.nifty_rg
import
RG_DISTRIBUTION_STRATEGIES
,
\
gc
as
RG_GC
from
nifty.lm.nifty_lm
import
LM_DISTRIBUTION_STRATEGIES
,
\
GL_DISTRIBUTION_STRATEGIES
,
\
from
nifty.lm.nifty_lm
import
LM_DISTRIBUTION_STRATEGIES
,
\
GL_DISTRIBUTION_STRATEGIES
,
\
HP_DISTRIBUTION_STRATEGIES
...
...
@@ -34,6 +34,7 @@ def custom_name_func(testcase_func, param_num, param):
parameterized
.
to_safe_name
(
"_"
.
join
(
str
(
x
)
for
x
in
param
.
args
)),
)
###############################################################################
###############################################################################
...
...
@@ -92,62 +93,33 @@ for param in itertools.product([(1,), (4, 6), (5, 8)],
[
False
],
DATAMODELS
[
'rg_space'
],
fft_modules
):
space_list
+=
[[
(
rg_space
(
shape
=
param
[
0
],
space_list
+=
[[
rg_space
(
shape
=
param
[
0
],
zerocenter
=
param
[
1
],
complexity
=
param
[
2
],
distances
=
param
[
3
],
harmonic
=
param
[
4
],
fft_module
=
param
[
6
]),
param
[
6
])
]]
fft_module
=
param
[
6
]),
param
[
5
]
]]
###############################################################################
###############################################################################
class
Test_field_init
(
unittest
.
TestCase
):
@
parameterized
.
expand
(
space_list
)
def
test_successfull_init_and_attributes
(
self
,
s
,
datamodel
):
f
=
field
(
domain
=
np
.
array
([
s
]),
dtype
=
s
.
dtype
,
datamodel
=
datamodel
)
assert
(
f
.
domain
[
0
]
is
s
)
assert
(
s
.
check_codomain
(
f
.
codomain
[
0
]))
@
parameterized
.
expand
(
itertools
.
product
([(
1
,),
(
4
,
6
),
(
5
,
8
)],
[
False
,
True
],
[
0
,
1
,
2
],
[
None
,
0.3
],
[
False
],
fft_modules
,
DATAMODELS
[
'rg_space'
]),
testcase_func_name
=
custom_name_func
)
def
test_successfull_init_and_attributes
(
self
,
shape
,
zerocenter
,
complexity
,
distances
,
harmonic
,
fft_module
,
datamodel
):
s
=
rg_space
(
shape
=
shape
,
zerocenter
=
zerocenter
,
complexity
=
complexity
,
distances
=
distances
,
harmonic
=
harmonic
,
fft_module
=
fft_module
)
f
=
field
(
domain
=
(
s
,),
dtype
=
s
.
dtype
,
datamodel
=
datamodel
)
assert
(
f
.
domain
[
0
]
is
s
)
assert
(
s
.
check_codomain
(
f
.
codomain
[
0
]))
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment