Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
ift
NIFTy
Commits
080116da
Commit
080116da
authored
Jun 20, 2016
by
csongor
Browse files
Fixed map-ings in nifty_field
parent
89729ee5
Changes
2
Hide whitespace changes
Inline
Side-by-side
nifty_core.py
View file @
080116da
...
...
@@ -1203,8 +1203,6 @@ class point_space(space):
dot : scalar
Inner product of the two arrays.
"""
x
=
self
.
cast
(
x
)
y
=
self
.
cast
(
y
)
result
=
x
.
vdot
(
y
)
...
...
nifty_field.py
View file @
080116da
...
...
@@ -214,10 +214,9 @@ class field(object):
if
kwargs
==
{}:
val
=
self
.
cast
(
0
)
else
:
val
=
map
(
lambda
z
:
self
.
get_random_values
(
domain
=
self
.
domain
,
codomain
=
z
,
**
kwargs
),
self
.
codomain
)
val
=
self
.
get_random_values
(
domain
=
self
.
domain
,
codomain
=
self
.
codomain
,
**
kwargs
)
self
.
set_val
(
new_val
=
val
,
copy
=
copy
)
def
_get_dtype_from_domain
(
self
,
domain
=
None
):
...
...
@@ -290,30 +289,15 @@ class field(object):
self
.
codomain
=
codomain
return
codomain
def
get_random_values
(
self
,
**
kwargs
):
def
get_random_values
(
self
,
domain
=
None
,
codomain
=
None
,
**
kwargs
):
raise
NotImplementedError
(
about
.
_errors
.
cstring
(
"ERROR: no generic instance method 'enforce_power'."
))
def
__len__
(
self
):
return
int
(
self
.
get_dim
()[
0
])
def
apply_scalar_function
(
self
,
function
,
inplace
=
False
):
if
inplace
:
working_field
=
self
else
:
working_field
=
self
.
copy_empty
()
data_object
=
map
(
lambda
z
:
self
.
domain
.
apply_scalar_function
(
z
,
function
,
inplace
),
self
.
get_val
())
working_field
.
set_val
(
data_object
)
return
working_field
def
copy
(
self
,
domain
=
None
,
codomain
=
None
):
copied_val
=
map
(
lambda
z
:
self
.
domain
.
unary_operation
(
z
,
op
=
'copy'
),
self
.
get_val
())
def
copy
(
self
,
domain
=
None
,
codomain
=
None
,
**
kwargs
):
copied_val
=
self
.
_unary_operation
(
self
.
get_val
(),
op
=
'copy'
,
**
kwargs
)
new_field
=
self
.
copy_empty
(
domain
=
domain
,
codomain
=
codomain
)
new_field
.
set_val
(
new_val
=
copied_val
)
return
new_field
...
...
@@ -369,9 +353,7 @@ class field(object):
"""
if
new_val
is
not
None
:
if
copy
:
new_val
=
map
(
lambda
z
:
self
.
unary_operation
(
z
,
'copy'
),
new_val
)
new_val
=
self
.
unary_operation
(
new_val
,
op
=
'copy'
)
self
.
val
=
self
.
cast
(
new_val
)
return
self
.
val
...
...
@@ -424,9 +406,6 @@ class field(object):
else
:
return
dim
def
_map
(
self
,
function
,
*
args
):
return
utilities
.
field_map
(
self
.
get_shape
(),
function
,
*
args
)
def
cast
(
self
,
x
=
None
,
dtype
=
None
):
if
dtype
is
not
None
:
dtype
=
np
.
dtype
(
dtype
)
...
...
@@ -571,7 +550,7 @@ class field(object):
self
.
codomain
=
new_codomain
return
self
.
codomain
def
weight
(
self
,
power
=
1
,
overwrite
=
False
):
def
weight
(
self
,
new_val
=
None
,
power
=
1
,
overwrite
=
False
):
"""
Returns the field values, weighted with the volume factors to a
given power. The field values will optionally be overwritten.
...
...
@@ -597,8 +576,12 @@ class field(object):
else
:
new_field
=
self
.
copy_empty
()
new_val
=
map
(
lambda
y
:
self
.
domain
.
calc_weight
(
y
,
power
=
power
),
self
.
get_val
())
if
new_val
is
None
:
new_val
=
self
.
get_val
()
for
ind
,
space
in
self
.
domain
:
new_val
=
space
.
calc_weigth
(
new_val
,
power
=
power
,
axis
=
self
.
_axis_list
[
ind
])
new_field
.
set_val
(
new_val
=
new_val
)
return
new_field
...
...
@@ -662,20 +645,11 @@ class field(object):
# Case 3: x is something else
else
:
# Cast the input in order to cure dtype and shape differences
casted_x
=
self
.
_
cast
_to_ishape
(
x
)
casted_x
=
self
.
cast
(
x
)
# Compute the dot respecting the fact of discrete/continous spaces
if
self
.
domain
.
discrete
or
bare
:
result
=
map
(
lambda
z1
,
z2
:
self
.
domain
.
calc_dot
(
z1
,
z2
),
self
.
get_val
(),
casted_x
)
else
:
result
=
map
(
lambda
z1
,
z2
:
self
.
domain
.
calc_dot
(
self
.
domain
.
calc_weight
(
z1
,
power
=
1
),
z2
),
self
.
get_val
(),
casted_x
)
if
not
(
np
.
isreal
(
self
.
get_val
())
or
bare
):
casted_x
=
self
.
weight
(
casted_x
,
power
=
1
)
result
=
self
.
get_val
().
dot
(
casted_x
)
return
np
.
sum
(
result
,
axis
=
axis
)
def
vdot
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -734,9 +708,11 @@ class field(object):
else
:
work_field
=
self
.
copy_empty
()
new_val
=
map
(
lambda
z
:
self
.
domain
.
unary_operation
(
z
,
'conjugate'
),
self
.
get_val
())
new_val
=
self
.
get_val
()
for
ind
,
space
in
self
.
domain
:
new_val
=
space
.
unary_operation
(
new_val
,
op
=
'conjugate'
,
axis
=
self
.
_axis_list
[
ind
])
work_field
.
set_val
(
new_val
=
new_val
)
return
work_field
...
...
@@ -779,10 +755,10 @@ class field(object):
else
:
assert
(
new_domain
.
check_codomain
(
new_codomain
))
new_val
=
map
(
lambda
z
:
self
.
domain
.
calc_transform
(
z
,
codomain
=
new_domain
,
**
kwargs
),
self
.
get_val
()
)
new_val
=
self
.
get_val
()
for
ind
,
space
in
self
.
domain
:
new_val
=
space
.
calc_transform
(
new_val
,
codomain
=
new_domain
,
axis
=
self
.
_axis_list
[
ind
],
**
kwargs
)
if
overwrite
:
return_field
=
self
...
...
@@ -825,9 +801,10 @@ class field(object):
else
:
new_field
=
self
.
copy_empty
()
new_val
=
map
(
lambda
z
:
self
.
domain
.
calc_smooth
(
z
,
sigma
=
sigma
,
**
kwargs
),
self
.
get_val
())
new_val
=
self
.
get_val
()
for
ind
,
space
in
self
.
domain
:
new_val
=
space
.
calc_smooth
(
new_val
,
sigma
=
sigma
,
axis
=
self
.
_axis_list
[
ind
],
**
kwargs
)
new_field
.
set_val
(
new_val
=
new_val
)
return
new_field
...
...
@@ -875,10 +852,12 @@ class field(object):
kwargs
.
__delitem__
(
"codomain"
)
about
.
warnings
.
cprint
(
"WARNING: codomain was removed from kwargs."
)
power_spectrum
=
map
(
lambda
z
:
self
.
domain
.
calc_power
(
z
,
codomain
=
self
.
codomain
,
**
kwargs
),
self
.
get_val
())
power_spectrum
=
self
.
get_val
()
for
ind
,
space
in
self
.
domain
:
power_spectrum
=
space
.
calc_smooth
(
power_spectrum
,
codomain
=
self
.
codomain
,
axis
=
self
.
_axis_list
[
ind
],
**
kwargs
)
return
power_spectrum
...
...
@@ -908,8 +887,7 @@ class field(object):
The new diagonal operator instance.
"""
any_zero_Q
=
map
(
lambda
z
:
(
z
==
0
).
any
(),
self
.
get_val
())
any_zero_Q
=
np
.
any
(
any_zero_Q
)
any_zero_Q
=
np
.
any
(
map
(
lambda
z
:
(
z
==
0
),
self
.
get_val
()))
if
any_zero_Q
:
raise
AttributeError
(
about
.
_errors
.
cstring
(
"ERROR: singular operator."
))
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment